.
+            long_answer_html_tag = doc_tokens[long_answer["start_token"]]["token"]
+            if long_answer_html_tag != "<P>":
": + global n_non_p + n_non_p += 1 + return + answer = clean_text( + short_answer["start_token"], short_answer["end_token"], doc_tokens, + doc_bytes) + before_answer = clean_text( + 0, short_answer["start_token"], doc_tokens, + doc_bytes, ignore_final_whitespace=False) + + elif anno_type == "no_answer": + answer = "" + before_answer = "" + + # Throw out long answer annotations + elif anno_type == "long_answer": + global n_long_ans + n_long_ans += 1 + continue + + anno_types.append(anno_type) + answer = {"answer_start": len(before_answer), + "text": answer} + answers.append(answer) + + if len(answers) == 0: + global n_long_ans_only + n_long_ans_only += 1 + return + + answers, is_impossible = reduce_annotations(anno_types, answers) + + paragraph = clean_text( + 0, len(doc_tokens), doc_tokens, + doc_bytes) + + return {"title": record["document_title"], + "paragraphs": + [{"context": paragraph, + "qas": [{"answers": answers, + "id": record["example_id"], + "question": question_text, + "is_impossible": is_impossible}]}]} + +def main(): + parser = argparse.ArgumentParser( + description="Convert the Natural Questions to SQuAD JSON format.") + parser.add_argument("--data_pattern", dest="data_pattern", + help=("A file pattern to match the Natural Questions " + "dataset."), + metavar="PATTERN", required=True) + parser.add_argument("--version", dest="version", + help="The version label in the output file.", + metavar="LABEL", default="nq-train") + parser.add_argument("--output_file", dest="output_file", + help="The name of the SQuAD JSON formatted output file.", + metavar="FILE", default="nq_as_squad.json") + args = parser.parse_args() + + root = logging.getLogger() + root.setLevel(logging.DEBUG) + + records = 0 + nq_as_squad = {"version": args.version, "data": []} + + for file in sorted(glob.iglob(args.data_pattern)): + logging.info("opening %s", file) + with gzip.GzipFile(file, "r") as f: + for line in f: + records += 1 + nq_record = json.loads(line) + try: + squad_record = nq_to_squad(nq_record) + except: + squad_record = None + global n_error + n_error += 1 + if squad_record: + nq_as_squad["data"].append(squad_record) + if records % 100 == 0: + logging.info("processed %s records", records) + print("Converted %s NQ records into %s SQuAD records." 
+    print("Converted %s NQ records into %s SQuAD records." %
+          (records, len(nq_as_squad["data"])))
+    print(f"Removed samples: yes/no: {n_yn} multi_short: {n_ms} non_para: {n_non_p} long_ans_only: {n_long_ans_only} errors: {n_error}")
+    print(f"Removed annotations: long_answer: {n_long_ans} short_answer: {n_short} no_answer: ~{n_no_ans}")
+
+    with open(args.output_file, "w") as f:
+        json.dump(nq_as_squad, f, indent=4)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/benchmarks/reader.py b/test/benchmarks/reader.py
new file mode 100644
index 000000000..534556879
--- /dev/null
+++ b/test/benchmarks/reader.py
@@ -0,0 +1,54 @@
+from utils import get_document_store, index_to_doc_store, get_reader
+from haystack.preprocessor.utils import eval_data_from_file
+from pathlib import Path
+import pandas as pd
+
+reader_models = ["deepset/roberta-base-squad2", "deepset/minilm-uncased-squad2",
+                 "deepset/bert-base-cased-squad2", "deepset/bert-large-uncased-whole-word-masking-squad2",
+                 "deepset/xlm-roberta-large-squad2", "distilbert-base-uncased-distilled-squad"]
+
+reader_types = ["farm"]
+data_dir = Path("../../data/squad20")
+filename = "dev-v2.0.json"
+# Note that this number is approximate - it was calculated using Bert Base Cased
+# This number could vary when using a different tokenizer
+n_passages = 12350
+
+doc_index = "eval_document"
+label_index = "label"
+
+def benchmark_reader():
+    reader_results = []
+    doc_store = get_document_store("elasticsearch")
+    docs, labels = eval_data_from_file(data_dir/filename)
+    index_to_doc_store(doc_store, docs, None, labels)
+    for reader_name in reader_models:
+        for reader_type in reader_types:
+            try:
+                reader = get_reader(reader_name, reader_type)
+                results = reader.eval(document_store=doc_store,
+                                      doc_index=doc_index,
+                                      label_index=label_index,
+                                      device="cuda")
+                # print(results)
+                results["passages_per_second"] = n_passages / results["reader_time"]
+                results["reader"] = reader_name
+                results["error"] = ""
+                reader_results.append(results)
+            except Exception as e:
+                results = {'EM': 0.,
+                           'f1': 0.,
+                           'top_n_accuracy': 0.,
+                           'top_n': 0,
+                           'reader_time': 0.,
+                           "passages_per_second": 0.,
+                           "seconds_per_query": 0.,
+                           'reader': reader_name,
+                           "error": e}
+                reader_results.append(results)
+    reader_df = pd.DataFrame.from_records(reader_results)
+    reader_df.to_csv("reader_results.csv")
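+
+# Usage sketch (assumes a running local Elasticsearch instance and the SQuAD 2.0
+# dev set at ../../data/squad20/dev-v2.0.json):
+#   python reader.py
+# Each benchmarked model contributes one row to reader_results.csv.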
+
+if __name__ == "__main__":
+    benchmark_reader()
\ No newline at end of file
diff --git a/test/benchmarks/reader_results.csv b/test/benchmarks/reader_results.csv
new file mode 100644
index 000000000..5fc081050
--- /dev/null
+++ b/test/benchmarks/reader_results.csv
@@ -0,0 +1,6 @@
+,EM,f1,top_n_accuracy,top_n,reader_time,seconds_per_query,passages_per_second,reader,error
+0,0.7589752233271532,0.8067985794671885,0.9671329849991572,5,133.79706027999998,0.011275666634080564,92.30397120949361,deepset/roberta-base-squad2,
+1,0.7359683128265633,0.7823306265318686,0.9714309792684982,5,125.22323393199997,0.010553112584864317,98.62387044489225,deepset/minilm-uncased-squad2,
+2,0.700825889094893,0.7490271600053505,0.9585369964604753,5,123.58959278499992,0.010415438461570867,99.92750782409666,deepset/bert-base-cased-squad2,
+3,0.7821506826226192,0.8264545708097472,0.9762346199224675,5,312.42233685099995,0.026329204184308102,39.529824033964466,deepset/bert-large-uncased-whole-word-masking-squad2,
+4,0.8099612337771785,0.8526275190954586,0.9772459126917242,5,314.3179854819998,0.026488958830439897,39.29142006004379,deepset/xlm-roberta-large-squad2,
\ No newline at end of file
diff --git a/test/benchmarks/results_to_json.py b/test/benchmarks/results_to_json.py
new file mode 100644
index 000000000..3acd7c5cf
--- /dev/null
+++ b/test/benchmarks/results_to_json.py
@@ -0,0 +1,101 @@
+import json
+import pandas as pd
+from pprint import pprint
+
+def reader():
+
+    model_rename_map = {
+        'deepset/roberta-base-squad2': "RoBERTa",
+        'deepset/minilm-uncased-squad2': "MiniLM",
+        'deepset/bert-base-cased-squad2': "BERT base",
+        'deepset/bert-large-uncased-whole-word-masking-squad2': "BERT large",
+        'deepset/xlm-roberta-large-squad2': "XLM-RoBERTa",
+    }
+
+    column_name_map = {
+        "f1": "F1",
+        "passages_per_second": "Speed",
+        "reader": "Model"
+    }
+
+    df = pd.read_csv("reader_results.csv")
+    df = df[["f1", "passages_per_second", "reader"]]
+    df["reader"] = df["reader"].map(model_rename_map)
+    df = df[list(column_name_map)]
+    df = df.rename(columns=column_name_map)
+    ret = [dict(row) for i, row in df.iterrows()]
+    print("Reader overview")
+    print(json.dumps(ret, indent=2))
+
+def retriever():
+
+    column_name_map = {
+        "model": "model",
+        "n_docs": "n_docs",
+        "docs_per_second": "index_speed",
+        "queries_per_second": "query_speed",
+        "map": "map"
+    }
+
+    name_cleaning = {
+        "dpr": "DPR",
+        "elastic": "BM25",
+        "elasticsearch": "ElasticSearch",
+        "faiss": "FAISS",
+        "faiss_flat": "FAISS (flat)",
+        "faiss_hnsw": "FAISS (HNSW)"
+    }
+
+    index = pd.read_csv("retriever_index_results.csv")
+    query = pd.read_csv("retriever_query_results.csv")
+    df = pd.merge(index, query,
+                  how="right",
+                  left_on=["retriever", "doc_store", "n_docs"],
+                  right_on=["retriever", "doc_store", "n_docs"])
+
+    df["retriever"] = df["retriever"].map(name_cleaning)
+    df["doc_store"] = df["doc_store"].map(name_cleaning)
+    df["model"] = df["retriever"] + " / " + df["doc_store"]
+
+    df = df[list(column_name_map)]
+    df = df.rename(columns=column_name_map)
+
+    print("Retriever overview")
+    print(retriever_overview(df))
+
+    print("Retriever MAP")
+    print(retriever_map(df))
+
+    print("Retriever Speed")
+    print(retriever_speed(df))
+
+
+def retriever_map(df):
+    columns = ["model", "n_docs", "map"]
+    df = df[columns]
+    ret = [list(row) for i, row in df.iterrows()]
+    ret = [columns] + ret
+    return json.dumps(ret, indent=4)
+
+
+def retriever_speed(df):
+    columns = ["model", "n_docs", "query_speed"]
+    df = df[columns]
+    ret = [list(row) for i, row in df.iterrows()]
+    ret = [columns] + ret
+    return json.dumps(ret, indent=4)
+
+
+def retriever_overview(df, chosen_n_docs=100_000):
+
+    df = df[df["n_docs"] == chosen_n_docs]
+    ret = [dict(row) for i, row in df.iterrows()]
+
+    return json.dumps(ret, indent=2)
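+
+# Output sketch (illustrative values only - the real numbers come from the
+# CSVs above): reader() prints a JSON list of records shaped like
+#   [{"F1": 0.81, "Speed": 92.3, "Model": "RoBERTa"}, ...]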
+
+if __name__ == "__main__":
+    reader()
+    retriever()
\ No newline at end of file
diff --git a/test/benchmarks/retriever.py b/test/benchmarks/retriever.py
new file mode 100644
index 000000000..d3c637e4e
--- /dev/null
+++ b/test/benchmarks/retriever.py
@@ -0,0 +1,224 @@
+import pandas as pd
+from pathlib import Path
+from time import perf_counter
+from utils import get_document_store, get_retriever, index_to_doc_store
+from haystack.preprocessor.utils import eval_data_from_file
+from haystack import Document
+import pickle
+import time
+from tqdm import tqdm
+import logging
+import datetime
+import random
+import traceback
+
+
+logger = logging.getLogger(__name__)
+logging.getLogger("haystack.retriever.base").setLevel(logging.WARN)
+logging.getLogger("elasticsearch").setLevel(logging.WARN)
+
+es_similarity = "dot_product"
+
+retriever_doc_stores = [
+    # ("elastic", "elasticsearch"),
+    # ("dpr", "elasticsearch"),
+    # ("dpr", "faiss_flat"),
+    ("dpr", "faiss_hnsw")
+]
+
+n_docs_options = [
+    1000,
+    10000,
+    100000,
+    500000,
+]
+
+# If set to None, querying will be run on all queries
+n_queries = None
+data_dir = Path("../../data/retriever")
+filename_gold = "nq2squad-dev.json"  # Found at s3://ext-haystack-retriever-eval
+filename_negative = "psgs_w100_minus_gold.tsv"  # Found at s3://ext-haystack-retriever-eval
+embeddings_dir = Path("embeddings")
+embeddings_filenames = ["wikipedia_passages_1m.pkl"]  # Found at s3://ext-haystack-retriever-eval
+
+doc_index = "eval_document"
+label_index = "label"
+
+seed = 42
+
+random.seed(seed)
+
+
+def prepare_data(data_dir, filename_gold, filename_negative, n_docs=None, n_queries=None, add_precomputed=False):
+    """
+    filename_gold points to a squad format file.
+    filename_negative points to a csv file where the first column is doc_id and second is document text.
+    If add_precomputed is True, this fn will look in the embeddings files for precomputed embeddings to add to each Document
+    """
+
+    gold_docs, labels = eval_data_from_file(data_dir / filename_gold)
+
+    # Reduce number of docs
+    gold_docs = gold_docs[:n_docs]
+
+    # Remove labels whose gold docs have been removed
+    doc_ids = [x.id for x in gold_docs]
+    labels = [x for x in labels if x.document_id in doc_ids]
+
+    # Filter labels down to n_queries
+    selected_queries = list(set(f"{x.document_id} | {x.question}" for x in labels))
+    selected_queries = selected_queries[:n_queries]
+    labels = [x for x in labels if f"{x.document_id} | {x.question}" in selected_queries]
+
+    n_neg_docs = max(0, n_docs - len(gold_docs))
+    neg_docs = prepare_negative_passages(data_dir, filename_negative, n_neg_docs)
+    docs = gold_docs + neg_docs
+
+    if add_precomputed:
+        docs = add_precomputed_embeddings(data_dir / embeddings_dir, embeddings_filenames, docs)
+
+    return docs, labels
+
+def prepare_negative_passages(data_dir, filename_negative, n_docs):
+    if n_docs == 0:
+        return []
+    with open(data_dir / filename_negative) as f:
+        lines = []
+        _ = f.readline()  # Skip column titles line
+        for _ in range(n_docs):
+            lines.append(f.readline()[:-1])
+
+    docs = []
+    for l in lines[:n_docs]:
+        id, text, title = l.split("\t")
+        d = {"text": text,
+             "meta": {"passage_id": int(id),
+                      "title": title}}
+        d = Document(**d)
+        docs.append(d)
+    return docs
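+
+# Note: prepare_negative_passages expects filename_negative to be tab-separated
+# with a header row, one passage per line in the style of the DPR passage dumps:
+#   id<TAB>text<TAB>title
+# Each row becomes a Document whose meta carries "passage_id" and "title".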
+
+def benchmark_indexing():
+
+    retriever_results = []
+    for n_docs in n_docs_options:
+        for retriever_name, doc_store_name in retriever_doc_stores:
+            doc_store = get_document_store(doc_store_name, es_similarity=es_similarity)
+
+            retriever = get_retriever(retriever_name, doc_store)
+
+            docs, _ = prepare_data(data_dir, filename_gold, filename_negative, n_docs=n_docs)
+
+            tic = perf_counter()
+            index_to_doc_store(doc_store, docs, retriever)
+            toc = perf_counter()
+            indexing_time = toc - tic
+
+            print(indexing_time)
+
+            retriever_results.append({
+                "retriever": retriever_name,
+                "doc_store": doc_store_name,
+                "n_docs": n_docs,
+                "indexing_time": indexing_time,
+                "docs_per_second": n_docs / indexing_time,
+                "date_time": datetime.datetime.now()})
+            retriever_df = pd.DataFrame.from_records(retriever_results)
+            retriever_df = retriever_df.sort_values(by=["doc_store", "retriever"])
+            retriever_df.to_csv("retriever_index_results.csv")
+
+            doc_store.delete_all_documents(index=doc_index)
+            doc_store.delete_all_documents(index=label_index)
+            time.sleep(10)
+            del doc_store
+            del retriever
+
+def benchmark_querying():
+    """ Benchmark the time it takes to perform querying.
+    Doc embeddings are loaded from file."""
+    retriever_results = []
+    for n_docs in n_docs_options:
+        for retriever_name, doc_store_name in retriever_doc_stores:
+            try:
+                logger.info(f"##### Start run: {retriever_name}, {doc_store_name}, {n_docs} docs ##### ")
+                doc_store = get_document_store(doc_store_name, es_similarity=es_similarity)
+                retriever = get_retriever(retriever_name, doc_store)
+                add_precomputed = retriever_name in ["dpr"]
+                # For DPR, precomputed embeddings are loaded from file
+                docs, labels = prepare_data(data_dir,
+                                            filename_gold,
+                                            filename_negative,
+                                            n_docs=n_docs,
+                                            n_queries=n_queries,
+                                            add_precomputed=add_precomputed)
+                logger.info("Start indexing...")
+                index_to_doc_store(doc_store, docs, retriever, labels)
+                logger.info("Start queries...")
+
+                raw_results = retriever.eval()
+                results = {
+                    "retriever": retriever_name,
+                    "doc_store": doc_store_name,
+                    "n_docs": n_docs,
+                    "n_queries": raw_results["n_questions"],
+                    "retrieve_time": raw_results["retrieve_time"],
+                    "queries_per_second": raw_results["n_questions"] / raw_results["retrieve_time"],
+                    "seconds_per_query": raw_results["retrieve_time"] / raw_results["n_questions"],
+                    "recall": raw_results["recall"],
+                    "map": raw_results["map"],
+                    "top_k": raw_results["top_k"],
+                    "date_time": datetime.datetime.now(),
+                    "error": None
+                }
+
+                doc_store.delete_all_documents()
+                time.sleep(5)
+                del doc_store
+                del retriever
+            except Exception as e:
+                tb = traceback.format_exc()
+                results = {
+                    "retriever": retriever_name,
+                    "doc_store": doc_store_name,
+                    "n_docs": n_docs,
+                    "n_queries": 0,
+                    "retrieve_time": 0.,
+                    "queries_per_second": 0.,
+                    "seconds_per_query": 0.,
+                    "recall": 0.,
+                    "map": 0.,
+                    "top_k": 0,
+                    "date_time": datetime.datetime.now(),
+                    "error": str(tb)
+                }
+            logger.info(results)
+            retriever_results.append(results)
+
+    retriever_df = pd.DataFrame.from_records(retriever_results)
+    retriever_df = retriever_df.sort_values(by=["doc_store", "retriever"])
+    retriever_df.to_csv("retriever_query_results.csv")
+
+
+
+
+def add_precomputed_embeddings(embeddings_dir, embeddings_filenames, docs):
+    ret = []
+    id_to_doc = {x.meta["passage_id"]: x for x in docs}
+    for ef in embeddings_filenames:
+        logger.info(f"Adding precomputed embeddings from {embeddings_dir / ef}")
+        filename = embeddings_dir / ef
+        embeds = pickle.load(open(filename, "rb"))
+        for i, vec in embeds:
+            if int(i) in id_to_doc:
+                curr = id_to_doc[int(i)]
+                curr.embedding = vec
+                ret.append(curr)
+    # In the official DPR repo, there are only 20594995 precomputed embeddings for 21015324 wikipedia passages
+    # If there isn't an embedding for a given doc, we remove it here
+    ret = [x for x in ret if x.embedding is not None]
+    logger.info(f"Embeddings loaded for {len(ret)}/{len(docs)} docs")
+    return ret
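+
+# Usage (assumes the gold/negative files noted above have been fetched from the
+# s3 bucket and that the chosen doc store backend is reachable):
+#   python retriever.py
+# benchmark_indexing() is commented out below; re-enable it to time indexing too.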
+
+if __name__ == "__main__":
+    # benchmark_indexing()
+    benchmark_querying()
diff --git a/test/benchmarks/retriever_index_results.csv b/test/benchmarks/retriever_index_results.csv
new file mode 100644
index 000000000..7a488ccf1
--- /dev/null
+++ b/test/benchmarks/retriever_index_results.csv
@@ -0,0 +1,13 @@
+retriever,doc_store,n_docs,indexing_time,docs_per_second,date_time,Notes
+dpr,elasticsearch,1000,14.16526405,70.59522482,2020-10-08 10:30:56,
+elastic,elasticsearch,1000,5.805040058,172.2640998,2020-10-08 10:30:25,
+elastic,elasticsearch,10000,22.56448254,443.1743553,2020-10-08 13:01:09,
+dpr,elasticsearch,10000,126.2442168,79.21154929,2020-10-08 13:03:32,
+dpr,elasticsearch,100000,1257.202958,79.54165185,2020-10-08 13:28:16,
+elastic,elasticsearch,100000,209.681252,476.9143596,2020-10-08 13:07:05,
+dpr,faiss_flat,1000,8.223732258,121.5992895,44112.24392,
+dpr,faiss_flat,10000,89.72649358,111.4498026,44112.24663,
+dpr,faiss_flat,100000,927.0740565,107.8662479,44112.56656,
+dpr,faiss_hnsw,1000,8.86507699,112.8021788,44113.37262,"hnsw 128,20,80"
+dpr,faiss_hnsw,10000,100.1804832,99.81984193,44113.37413,"hnsw 128,20,80"
+dpr,faiss_hnsw,100000,1084.063917,92.24548333,44113.38721,"hnsw 128,20,80"
\ No newline at end of file
diff --git a/test/benchmarks/retriever_query_results.csv b/test/benchmarks/retriever_query_results.csv
new file mode 100644
index 000000000..8fc1ed958
--- /dev/null
+++ b/test/benchmarks/retriever_query_results.csv
@@ -0,0 +1,17 @@
+retriever,doc_store,n_docs,n_queries,retrieve_time,queries_per_second,seconds_per_query,recall,map,top_k,date_time,error,name,note
+dpr,elasticsearch,1000,1085,26.592,40.802,0.025,0.991,0.929,10,2020-10-07 15:06:57,,dpr-elasticsearch,
+dpr,elasticsearch,10000,5791,214.425,27.007,0.037,0.975,0.898,10,2020-10-07 15:11:35,,dpr-elasticsearch,
+dpr,elasticsearch,100000,5791,886.045,6.536,0.153,0.958,0.863,10,2020-10-07 15:30:52,,dpr-elasticsearch,
+dpr,elasticsearch,500000,5791,3824.624,1.514,0.660,0.930,0.805,10,2020-10-07 17:44:02,,dpr-elasticsearch,
+dpr,faiss_flat,1000,1085,27.092,40.048,0.025,0.991,0.929,10,2020-10-07 13:06:35,,dpr-faiss_flat,
+dpr,faiss_flat,10000,5791,241.524,23.977,0.042,0.975,0.898,10,2020-10-07 13:17:21,,dpr-faiss_flat,
+dpr,faiss_flat,100000,5791,1148.181,5.044,0.198,0.958,0.863,10,2020-10-07 14:04:51,,dpr-faiss_flat,
+dpr,faiss_flat,500000,5791,5308.016,1.091,0.917,0.930,0.805,10,2020-10-08 10:01:32,,dpr-faiss_flat,
+elastic,elasticsearch,1000,1085,4.657,232.978,0.004,0.891,0.748,10,2020-10-07 13:04:47,,elastic-elasticsearch,
+elastic,elasticsearch,10000,5791,34.509,167.810,0.006,0.811,0.661,10,2020-10-07 13:07:52,,elastic-elasticsearch,
+elastic,elasticsearch,100000,5791,35.529,162.996,0.006,0.717,0.560,10,2020-10-07 13:21:48,,elastic-elasticsearch,
+elastic,elasticsearch,500000,5791,60.645,95.491,0.010,0.624,0.452,10,2020-10-07 16:14:52,,elastic-elasticsearch,
+dpr,faiss_hnsw,1000,1085,28.640,37.884,0.026,0.991,0.929,10,2020-10-09 07:19:29,,dpr-faiss_hnsw,"128,20,80"
+dpr,faiss_hnsw,10000,5791,173.272,33.421,0.030,0.972,0.896,10,2020-10-09 07:23:28,,dpr-faiss_hnsw,"128,20,80"
+dpr,faiss_hnsw,100000,5791,451.884,12.815,0.078,0.940,0.849,10,2020-10-09 07:37:56,,dpr-faiss_hnsw,"128,20,80"
+dpr,faiss_hnsw,500000,5791,1777.023,3.259,0.307,0.882,0.766,10,2020-10-09,,dpr-faiss_hnsw,"128,20,80"
\ No newline at end of file
diff --git a/test/benchmarks/run.py b/test/benchmarks/run.py
new file mode 100644
index 000000000..d318fd529
--- /dev/null
+++ b/test/benchmarks/run.py
@@ -0,0 +1,24 @@
+from retriever import benchmark_indexing, benchmark_querying
+from reader import benchmark_reader
+import argparse
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument('--reader', default=False, action="store_true",
+                    help='Perform Reader benchmarks')
+parser.add_argument('--retriever_index', default=False, action="store_true",
+                    help='Perform Retriever indexing benchmarks')
+parser.add_argument('--retriever_query', default=False, action="store_true",
+                    help='Perform Retriever querying benchmarks')
+parser.add_argument('--ci', default=False, action="store_true",
+                    help='Perform a smaller subset of benchmarks that are quicker to run')
+
+args = parser.parse_args()
+
+if args.retriever_index:
+    benchmark_indexing()
+if args.retriever_query:
+    benchmark_querying()
+if args.reader:
+    benchmark_reader()
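+
+# Example invocations:
+#   python run.py --retriever_index --retriever_query
+#   python run.py --reader
+# Note: --ci is parsed above but not yet passed through to the benchmark functions.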
diff --git a/test/benchmarks/utils.py b/test/benchmarks/utils.py
new file mode 100644
index 000000000..c8edf0dd7
--- /dev/null
+++ b/test/benchmarks/utils.py
@@ -0,0 +1,106 @@
+import os
+from haystack import Document
+from haystack.document_store.sql import SQLDocumentStore
+from haystack.document_store.memory import InMemoryDocumentStore
+from haystack.document_store.elasticsearch import Elasticsearch, ElasticsearchDocumentStore
+from haystack.document_store.faiss import FAISSDocumentStore
+from haystack.retriever.sparse import ElasticsearchRetriever, TfidfRetriever
+from haystack.retriever.dense import DensePassageRetriever
+from haystack.reader.farm import FARMReader
+from haystack.reader.transformers import TransformersReader
+from time import perf_counter
+import pandas as pd
+import json
+import logging
+import subprocess
+import time
+
+from pathlib import Path
+logger = logging.getLogger(__name__)
+
+
+reader_models = ["deepset/roberta-base-squad2", "deepset/minilm-uncased-squad2", "deepset/bert-base-cased-squad2", "deepset/bert-large-uncased-whole-word-masking-squad2", "deepset/xlm-roberta-large-squad2"]
+reader_types = ["farm"]
+data_dir_reader = Path("../../data/squad20")
+filename_reader = "dev-v2.0.json"
+
+doc_index = "eval_document"
+label_index = "label"
+
+def get_document_store(document_store_type, es_similarity='cosine'):
+    """ TODO This method is taken from test/conftest.py but maybe should be within Haystack.
+    Perhaps a class method of DocStore that just takes string for type of DocStore"""
+    if document_store_type == "sql":
+        if os.path.exists("haystack_test.db"):
+            os.remove("haystack_test.db")
+        document_store = SQLDocumentStore(url="sqlite:///haystack_test.db")
+    elif document_store_type == "memory":
+        document_store = InMemoryDocumentStore()
+    elif document_store_type == "elasticsearch":
+        # make sure we start from a fresh index
+        client = Elasticsearch()
+        client.indices.delete(index='haystack_test*', ignore=[404])
+        document_store = ElasticsearchDocumentStore(index="eval_document", similarity=es_similarity)
+    elif document_store_type in ("faiss_flat", "faiss_hnsw"):
+        if document_store_type == "faiss_flat":
+            index_type = "Flat"
+        elif document_store_type == "faiss_hnsw":
+            index_type = "HNSW"
+
+        #TEMP FIX for issue with deleting docs
+        # status = subprocess.run(
+        #     ['docker rm -f haystack-postgres'],
+        #     shell=True)
+        # time.sleep(3)
+        # try:
+        #     document_store = FAISSDocumentStore(sql_url="postgresql://postgres:password@localhost:5432/haystack",
+        #                                         faiss_index_factory_str=index_type)
+        # except:
+        # Launch a postgres instance & create empty DB
Start a new instance...") + status = subprocess.run( + ['docker rm -f haystack-postgres'], + shell=True) + time.sleep(1) + status = subprocess.run( + ['docker run --name haystack-postgres -p 5432:5432 -e POSTGRES_PASSWORD=password -d postgres'], + shell=True) + time.sleep(3) + status = subprocess.run( + ['docker exec -it haystack-postgres psql -U postgres -c "CREATE DATABASE haystack;"'], shell=True) + time.sleep(1) + document_store = FAISSDocumentStore(sql_url="postgresql://postgres:password@localhost:5432/haystack", + faiss_index_factory_str=index_type) + + else: + raise Exception(f"No document store fixture for '{document_store_type}'") + return document_store + +def get_retriever(retriever_name, doc_store): + if retriever_name == "elastic": + return ElasticsearchRetriever(doc_store) + if retriever_name == "tfidf": + return TfidfRetriever(doc_store) + if retriever_name == "dpr": + return DensePassageRetriever(document_store=doc_store, + query_embedding_model="facebook/dpr-question_encoder-single-nq-base", + passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base", + use_gpu=True) + +def get_reader(reader_name, reader_type, max_seq_len=384): + reader_class = None + if reader_type == "farm": + reader_class = FARMReader + elif reader_type == "transformers": + reader_class = TransformersReader + return reader_class(reader_name, top_k_per_candidate=4, max_seq_len=max_seq_len) + +def index_to_doc_store(doc_store, docs, retriever, labels=None): + doc_store.write_documents(docs, doc_index) + if labels: + doc_store.write_labels(labels, index=label_index) + # these lines are not run if the docs.embedding field is already populated with precomputed embeddings + # See the prepare_data() fn in the retriever benchmark script + elif callable(getattr(retriever, "embed_passages", None)) and docs[0].embedding is None: + doc_store.update_embeddings(retriever, index=doc_index) +