Integrate sentence transformers into benchmarks (#843)
* Integrate sentence transformers into benchmarks
* Add doc store asserts
* Switch data downloads from s3 client to https; add license info
* Fix mypy, revert config

Co-authored-by: Malte Pietsch <malte.pietsch@deepset.ai>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
parent d38c07e0ee
commit 837dea4e6d
````diff
@@ -462,7 +462,7 @@ that are most relevant to the query.
 #### embed
 
 ```python
- | embed(texts: Union[List[str], str]) -> List[np.ndarray]
+ | embed(texts: Union[List[List[str]], List[str], str]) -> List[np.ndarray]
 ```
 
 Create embeddings for each text in a list of texts using the retrievers model (`self.embedding_model`)
@@ -496,7 +496,7 @@ Embeddings, one per input queries
 #### embed\_passages
 
 ```python
- | embed_passages(docs: List[Document]) -> List[np.ndarray]
+ | embed_passages(docs: List[Document]) -> Union[List[str], List[List[str]]]
 ```
 
 Create embeddings for a list of passages. For this Retriever type: The same as calling .embed()
````
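To make the widened `embed` signature concrete, here is a minimal usage sketch. The document store setup is an assumption for illustration (only the model name and `model_format` come from this PR's benchmark code):

```python
# Usage sketch: assumes a running Elasticsearch; model name taken from this PR's benchmarks.
from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
from haystack.retriever.dense import EmbeddingRetriever

doc_store = ElasticsearchDocumentStore(similarity="cosine")
retriever = EmbeddingRetriever(document_store=doc_store,
                               embedding_model="nq-distilbert-base-v1",
                               model_format="sentence_transformers")

# All three input shapes covered by the new Union type:
emb = retriever.embed("who wrote hamlet?")                       # str
emb = retriever.embed(["who wrote hamlet?", "what is a quark"])  # List[str]
emb = retriever.embed([["Hamlet", "Hamlet is a tragedy."]])      # List[List[str]] as [title, text]
```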
```diff
@@ -524,7 +524,7 @@ class EmbeddingRetriever(BaseRetriever):
                                                            top_k=top_k, index=index)
         return documents
 
-    def embed(self, texts: Union[List[str], str]) -> List[np.ndarray]:
+    def embed(self, texts: Union[List[List[str]], List[str], str]) -> List[np.ndarray]:
         """
         Create embeddings for each text in a list of texts using the retrievers model (`self.embedding_model`)
@@ -543,9 +543,9 @@ class EmbeddingRetriever(BaseRetriever):
             emb = self.embedding_model.inference_from_dicts(dicts=[{"text": t} for t in texts])
             emb = [(r["vec"]) for r in emb]
         elif self.model_format == "sentence_transformers":
-            # text is single string, sentence-transformers needs a list of strings
+            # texts can be a list of strings or a list of [title, text]
             # get back list of numpy embedding vectors
-            emb = self.embedding_model.encode(texts)
+            emb = self.embedding_model.encode(texts, batch_size=200, show_progress_bar=False)
             emb = [r for r in emb]
         return emb
@@ -558,13 +558,15 @@ class EmbeddingRetriever(BaseRetriever):
         """
         return self.embed(texts)
 
-    def embed_passages(self, docs: List[Document]) -> List[np.ndarray]:
+    def embed_passages(self, docs: List[Document]) -> Union[List[str], List[List[str]]]:
         """
         Create embeddings for a list of passages. For this Retriever type: The same as calling .embed()
 
         :param docs: List of documents to embed
         :return: Embeddings, one per input passage
         """
-        texts = [d.text for d in docs]
-
-        return self.embed(texts)
+        if self.model_format == "sentence_transformers":
+            passages = [[d.meta["name"] if d.meta and "name" in d.meta else "", d.text] for d in docs]  # type: ignore
+        else:
+            passages = [d.text for d in docs]  # type: ignore
+        return self.embed(passages)
```
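The behavioural change in `embed_passages` is easiest to see in isolation: for sentence-transformers models the passage text is now paired with its document title. A standalone sketch of the new passage construction, with stand-in documents (the list comprehension is taken verbatim from the diff; the example documents and the `from haystack import Document` path for this Haystack version are assumptions):

```python
# Standalone sketch of the new embed_passages() input construction.
from haystack import Document

docs = [
    Document(text="Hamlet is a tragedy by William Shakespeare.", meta={"name": "Hamlet"}),
    Document(text="Lima is the capital of Peru."),  # no "name" in meta -> empty title
]

model_format = "sentence_transformers"
if model_format == "sentence_transformers":
    # sentence-transformers models can consume [title, text] pairs
    passages = [[d.meta["name"] if d.meta and "name" in d.meta else "", d.text] for d in docs]
else:
    passages = [d.text for d in docs]

print(passages)
# [['Hamlet', 'Hamlet is a tragedy by William Shakespeare.'], ['', 'Lima is the capital of Peru.']]
```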
```diff
@@ -6,6 +6,10 @@
       "elastic",
       "elasticsearch"
     ],
+    [
+      "sentence_transformers",
+      "elasticsearch"
+    ],
     [
       "dpr",
       "elasticsearch"
@@ -41,7 +45,7 @@
     }
   },
   "filenames": {
-    "data_s3_url": "s3://ext-haystack-retriever-eval/",
+    "data_s3_url": "https://ext-haystack-retriever-eval.s3-eu-west-1.amazonaws.com/",
     "data_dir": "../../data/retriever/",
     "filename_gold": "nq2squad-dev.json",
     "filenames_negative": {
@@ -55,4 +59,4 @@
       "100000": "wikipedia_passages_100k.pkl",
       "1000000": "wikipedia_passages_1m.pkl"}
   }
-}
+}
```
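Note that the config key keeps its old name `data_s3_url` even though its value is now a plain HTTPS endpoint; the benchmark scripts simply concatenate it with a filename. A quick sketch of the resulting URL, using only the values from the config above:

```python
# How the benchmark scripts compose a download URL from the config values.
data_s3_url = "https://ext-haystack-retriever-eval.s3-eu-west-1.amazonaws.com/"
filename_gold = "nq2squad-dev.json"

url = data_s3_url + filename_gold
print(url)
# https://ext-haystack-retriever-eval.s3-eu-west-1.amazonaws.com/nq2squad-dev.json
```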
```diff
@@ -1,7 +1,7 @@
 import pandas as pd
 from pathlib import Path
 from time import perf_counter
-from utils import get_document_store, get_retriever, index_to_doc_store, load_config
+from utils import get_document_store, get_retriever, index_to_doc_store, load_config, download_from_url
 from haystack.preprocessor.utils import eval_data_from_json
 from haystack.document_store.faiss import FAISSDocumentStore
 
@@ -15,7 +15,7 @@ import random
 import traceback
 import os
 import requests
-from farm.file_utils import download_from_s3
+from farm.file_utils import http_get
 import json
 from results_to_json import retriever as retriever_json
 from templates import RETRIEVER_TEMPLATE, RETRIEVER_MAP_TEMPLATE, RETRIEVER_SPEED_TEMPLATE
@@ -50,7 +50,7 @@ def benchmark_indexing(n_docs_options, retriever_doc_stores, data_dir, filename_
         docs, _ = prepare_data(data_dir=data_dir,
                                filename_gold=filename_gold,
                                filename_negative=filename_negative,
-                               data_s3_url=data_s3_url,
+                               remote_url=data_s3_url,
                                embeddings_filenames=embeddings_filenames,
                                embeddings_dir=embeddings_dir,
                                n_docs=n_docs)
```
```diff
@@ -134,7 +134,7 @@ def benchmark_querying(n_docs_options,
     for retriever_name, doc_store_name in retriever_doc_stores:
         try:
             logger.info(f"##### Start querying run: {retriever_name}, {doc_store_name}, {n_docs} docs ##### ")
-            if retriever_name == "elastic":
+            if retriever_name in ["elastic", "sentence_transformers"]:
                 similarity = "cosine"
             else:
                 similarity = "dot_product"
```
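Sentence-transformers models like `nq-distilbert-base-v1` are generally trained against cosine similarity, whereas DPR is trained with dot product, which is presumably why the benchmark now selects `cosine` for both `elastic` and `sentence_transformers`. On L2-normalized vectors the two scores coincide, as this small numpy check illustrates (the check itself is an illustration, not part of the PR):

```python
# Cosine similarity equals the dot product once vectors are L2-normalized.
import numpy as np

a, b = np.random.rand(768), np.random.rand(768)
cosine = (a @ b) / (np.linalg.norm(a) * np.linalg.norm(b))
a_n, b_n = a / np.linalg.norm(a), b / np.linalg.norm(b)
assert np.isclose(cosine, a_n @ b_n)
```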
```diff
@@ -145,7 +145,7 @@ def benchmark_querying(n_docs_options,
         docs, labels = prepare_data(data_dir=data_dir,
                                     filename_gold=filename_gold,
                                     filename_negative=filename_negative,
-                                    data_s3_url=data_s3_url,
+                                    remote_url=data_s3_url,
                                     embeddings_filenames=embeddings_filenames,
                                     embeddings_dir=embeddings_dir,
                                     n_docs=n_docs,
@@ -254,7 +254,7 @@ def add_precomputed_embeddings(embeddings_dir, embeddings_filenames, docs):
     return ret
 
 
-def prepare_data(data_dir, filename_gold, filename_negative, data_s3_url, embeddings_filenames, embeddings_dir, n_docs=None, n_queries=None, add_precomputed=False):
+def prepare_data(data_dir, filename_gold, filename_negative, remote_url, embeddings_filenames, embeddings_dir, n_docs=None, n_queries=None, add_precomputed=False):
     """
     filename_gold points to a squad format file.
     filename_negative points to a csv file where the first column is doc_id and second is document text.
@@ -262,11 +262,11 @@ def prepare_data(data_dir, filename_gold, filename_negative, data_s3_url, embed
     """
 
     logging.getLogger("farm").setLevel(logging.INFO)
-    download_from_s3(data_s3_url + filename_gold, cache_dir=data_dir)
-    download_from_s3(data_s3_url + filename_negative, cache_dir=data_dir)
+    download_from_url(remote_url + filename_gold, filepath=data_dir + filename_gold)
+    download_from_url(remote_url + filename_negative, filepath=data_dir + filename_negative)
     if add_precomputed:
        for embedding_filename in embeddings_filenames:
-            download_from_s3(data_s3_url + str(embeddings_dir) + embedding_filename, cache_dir=data_dir)
+            download_from_url(remote_url + str(embeddings_dir) + embedding_filename, filepath=data_dir + str(embeddings_dir) + embedding_filename)
     logging.getLogger("farm").setLevel(logging.WARN)
 
     gold_docs, labels = eval_data_from_json(data_dir + filename_gold)
```
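For clarity, a sketch of what one of the rewritten download calls resolves to at runtime. The values are the ones from `config.json`; the key difference from `download_from_s3` is that the local target path is now explicit rather than a hashed cache entry, so `eval_data_from_json` can read `data_dir + filename_gold` directly:

```python
# Sketch of a single rewritten download call with the config values filled in
# (assumes the benchmarks' utils module is on the path).
from utils import download_from_url  # helper added by this PR, see below

data_dir = "../../data/retriever/"
remote_url = "https://ext-haystack-retriever-eval.s3-eu-west-1.amazonaws.com/"
filename_gold = "nq2squad-dev.json"

# old: download_from_s3(remote_url + filename_gold, cache_dir=data_dir)
download_from_url(remote_url + filename_gold, filepath=data_dir + filename_gold)
```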
```diff
@@ -1,3 +1,9 @@
+# The benchmarks use
+# - a variant of the Natural Questions Dataset (https://ai.google.com/research/NaturalQuestions) from Google Research
+#   licensed under CC BY-SA 3.0 (https://creativecommons.org/licenses/by-sa/3.0/)
+# - the SQuAD 2.0 Dataset (https://rajpurkar.github.io/SQuAD-explorer/) from Rajpurkar et al.
+#   licensed under CC BY-SA 4.0 (https://creativecommons.org/licenses/by-sa/4.0/legalcode)
+
 from retriever import benchmark_indexing, benchmark_querying
 from reader import benchmark_reader
 from utils import load_config
```
```diff
@@ -4,14 +4,16 @@ from haystack.document_store.memory import InMemoryDocumentStore
 from haystack.document_store.elasticsearch import Elasticsearch, ElasticsearchDocumentStore
 from haystack.document_store.faiss import FAISSDocumentStore
 from haystack.retriever.sparse import ElasticsearchRetriever, TfidfRetriever
-from haystack.retriever.dense import DensePassageRetriever
+from haystack.retriever.dense import DensePassageRetriever, EmbeddingRetriever
 from haystack.reader.farm import FARMReader
 from haystack.reader.transformers import TransformersReader
+from farm.file_utils import http_get
 
 import logging
 import subprocess
 import time
 import json
 
+from typing import Union
 from pathlib import Path
 logger = logging.getLogger(__name__)
@@ -29,6 +31,7 @@ def get_document_store(document_store_type, similarity='dot_product'):
         if os.path.exists("haystack_test.db"):
             os.remove("haystack_test.db")
         document_store = SQLDocumentStore(url="sqlite:///haystack_test.db")
+        assert document_store.get_document_count() == 0
     elif document_store_type == "memory":
         document_store = InMemoryDocumentStore()
     elif document_store_type == "elasticsearch":
@@ -36,6 +39,7 @@ def get_document_store(document_store_type, similarity='dot_product'):
         client = Elasticsearch()
         client.indices.delete(index='haystack_test*', ignore=[404])
         document_store = ElasticsearchDocumentStore(index="eval_document", similarity=similarity, timeout=3000)
+        assert document_store.get_document_count(index="eval_document") == 0
     elif document_store_type in("faiss_flat", "faiss_hnsw"):
         if document_store_type == "faiss_flat":
             index_type = "Flat"
@@ -55,10 +59,10 @@ def get_document_store(document_store_type, similarity='dot_product'):
         document_store = FAISSDocumentStore(sql_url="postgresql://postgres:password@localhost:5432/haystack",
                                             faiss_index_factory_str=index_type,
                                             similarity=similarity)
+        assert document_store.get_document_count() == 0
 
     else:
         raise Exception(f"No document store fixture for '{document_store_type}'")
-    assert document_store.get_document_count() == 0
     return document_store
 
 def get_retriever(retriever_name, doc_store):
@@ -72,6 +76,11 @@ def get_retriever(retriever_name, doc_store):
                                        passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
                                        use_gpu=True,
                                        use_fast_tokenizers=False)
+    if retriever_name == "sentence_transformers":
+        return EmbeddingRetriever(document_store=doc_store,
+                                  embedding_model="nq-distilbert-base-v1",
+                                  use_gpu=True,
+                                  model_format="sentence_transformers")
 
 def get_reader(reader_name, reader_type, max_seq_len=384):
     reader_class = None
```
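A sketch of how the new `sentence_transformers` branch is exercised end to end. The helper calls exist in this file; the example document stands in for the output of `prepare_data()` and is an assumption, as is the `Document` import path for this Haystack version:

```python
# Illustrative wiring: sentence_transformers runs against Elasticsearch with cosine similarity.
from haystack import Document
from utils import get_document_store, get_retriever, index_to_doc_store

doc_store = get_document_store("elasticsearch", similarity="cosine")
retriever = get_retriever("sentence_transformers", doc_store)

docs = [Document(text="Lima is the capital of Peru.", meta={"name": "Lima"})]  # stand-in for prepare_data() output
index_to_doc_store(doc_store, docs, retriever)  # triggers embed_passages() via update_embeddings()
```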
```diff
@@ -87,7 +96,7 @@ def index_to_doc_store(doc_store, docs, retriever, labels=None):
         doc_store.write_labels(labels, index=label_index)
     # these lines are not run if the docs.embedding field is already populated with precomputed embeddings
     # See the prepare_data() fn in the retriever benchmark script
-    elif callable(getattr(retriever, "embed_passages", None)) and docs[0].embedding is None:
+    if callable(getattr(retriever, "embed_passages", None)) and docs[0].embedding is None:
         doc_store.update_embeddings(retriever, index=doc_index)
 
 def load_config(config_filename, ci):
@@ -107,3 +116,25 @@ def load_config(config_filename, ci):
     return params, filenames
 
 
+def download_from_url(url: str, filepath: Union[str, Path]):
+    """
+    Download from a url to a local file. Skip already existing files.
+
+    :param url: Url
+    :param filepath: local path where the url content shall be stored
+    :return: local path of the downloaded file
+    """
+
+    logger.info(f"Downloading {url}")
+    # Create local folder
+    folder, filename = os.path.split(filepath)
+    if not os.path.exists(folder):
+        os.makedirs(folder)
+    # Download file if not present locally
+    if os.path.exists(filepath):
+        logger.info(f"Skipping {url} (exists locally)")
+    else:
+        logger.info(f"Downloading {url} to {filepath} ")
+        with open(filepath, "wb") as file:
+            http_get(url=url, temp_file=file)
+    return filepath
```
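Finally, a usage sketch for the new helper, highlighting its skip-if-exists behavior. The URL and local path below are the benchmark defaults from `config.json`:

```python
# First call downloads; a second call is skipped because the file already exists locally.
from utils import download_from_url

path = download_from_url(
    url="https://ext-haystack-retriever-eval.s3-eu-west-1.amazonaws.com/nq2squad-dev.json",
    filepath="../../data/retriever/nq2squad-dev.json",
)
print(path)  # ../../data/retriever/nq2squad-dev.json
```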