Integrate BEIR (#2333)

* introduce eval_beir() to Pipeline

* add beir dependency

* Update Documentation & Code Style

* top_k_values added + refactoring

* Update Documentation & Code Style

* enable titles during beir eval

* Update Documentation & Code Style

* raise HaystackError instead of PipelineError

* get rid of forced dedicated index

* minor docstring and comment fixes

* show warning on default index deletion

* Update Documentation & Code Style

* add delete_index to MockDocumentStore

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
tstadel 2022-03-21 19:04:28 +01:00 committed by GitHub
parent 98fa48cc4c
commit ca86cc834d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 428 additions and 14 deletions

View File

@ -250,6 +250,25 @@ When set to None (default) all available eval documents are used.
same question might be found in different contexts.
- `headers`: Custom HTTP headers to pass to document store client if supported (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='} for basic authentication)
<a id="base.BaseDocumentStore.delete_index"></a>
#### delete\_index
```python
@abstractmethod
def delete_index(index: str)
```
Delete an existing index. The index including all data will be removed.
**Arguments**:
- `index`: The name of the index to delete.
**Returns**:
None
<a id="base.BaseDocumentStore.run"></a>
#### run
@ -1842,6 +1861,24 @@ Example:
None
<a id="memory.InMemoryDocumentStore.delete_index"></a>
#### delete\_index
```python
def delete_index(index: str)
```
Delete an existing index. The index including all data will be removed.
**Arguments**:
- `index`: The name of the index to delete.
**Returns**:
None
<a id="memory.InMemoryDocumentStore.delete_labels"></a>
#### delete\_labels
@ -2127,6 +2164,24 @@ have their ID in the list).
None
<a id="sql.SQLDocumentStore.delete_index"></a>
#### delete\_index
```python
def delete_index(index: str)
```
Delete an existing index. The index including all data will be removed.
**Arguments**:
- `index`: The name of the index to delete.
**Returns**:
None
<a id="sql.SQLDocumentStore.delete_labels"></a>
#### delete\_labels
@ -2371,6 +2426,24 @@ have their ID in the list).
None
<a id="faiss.FAISSDocumentStore.delete_index"></a>
#### delete\_index
```python
def delete_index(index: str)
```
Delete an existing index. The index including all data will be removed.
**Arguments**:
- `index`: The name of the index to delete.
**Returns**:
None
<a id="faiss.FAISSDocumentStore.query_by_embedding"></a>
#### query\_by\_embedding
@ -2641,6 +2714,24 @@ have their ID in the list).
None
<a id="milvus1.Milvus1DocumentStore.delete_index"></a>
#### delete\_index
```python
def delete_index(index: str)
```
Delete an existing index. The index including all data will be removed.
**Arguments**:
- `index`: The name of the index to delete.
**Returns**:
None
<a id="milvus1.Milvus1DocumentStore.get_all_documents_generator"></a>
#### get\_all\_documents\_generator
@ -2932,6 +3023,24 @@ Example: {"name": ["some", "more"], "category": ["only_one"]}
None
<a id="milvus2.Milvus2DocumentStore.delete_index"></a>
#### delete\_index
```python
def delete_index(index: str)
```
Delete an existing index. The index including all data will be removed.
**Arguments**:
- `index`: The name of the index to delete.
**Returns**:
None
<a id="milvus2.Milvus2DocumentStore.get_all_documents_generator"></a>
#### get\_all\_documents\_generator
@ -3565,6 +3674,24 @@ operation.
None
<a id="weaviate.WeaviateDocumentStore.delete_index"></a>
#### delete\_index
```python
def delete_index(index: str)
```
Delete an existing index. The index including all data will be removed.
**Arguments**:
- `index`: The name of the index to delete.
**Returns**:
None
<a id="weaviate.WeaviateDocumentStore.delete_labels"></a>
#### delete\_labels

View File

@ -413,7 +413,7 @@ Set the component for a node in the Pipeline.
#### run
```python
def run(query: Optional[str] = None, file_paths: Optional[List[str]] = None, labels: Optional[MultiLabel] = None, documents: Optional[List[Document]] = None, meta: Optional[dict] = None, params: Optional[dict] = None, debug: Optional[bool] = None)
def run(query: Optional[str] = None, file_paths: Optional[List[str]] = None, labels: Optional[MultiLabel] = None, documents: Optional[List[Document]] = None, meta: Optional[Union[dict, List[dict]]] = None, params: Optional[dict] = None, debug: Optional[bool] = None)
```
Runs the pipeline, one node at a time.
@ -434,6 +434,35 @@ about their execution. By default these include the input parameters
they received and the output they generated. All debug information can
then be found in the dict returned by this method under the key "_debug"
<a id="base.Pipeline.eval_beir"></a>
#### eval\_beir
```python
@classmethod
def eval_beir(cls, index_pipeline: Pipeline, query_pipeline: Pipeline, index_params: dict = {}, query_params: dict = {}, dataset: str = "scifact", dataset_dir: Path = Path("."), top_k_values: List[int] = [1, 3, 5, 10, 100, 1000], keep_index: bool = False) -> Tuple[Dict[str, float], Dict[str, float], Dict[str, float], Dict[str, float]]
```
Runs information retrieval evaluation of a pipeline using BEIR on a specified BEIR dataset.
See https://github.com/beir-cellar/beir for more information.
**Arguments**:
- `index_pipeline`: The indexing pipeline to use.
- `query_pipeline`: The query pipeline to evaluate.
- `index_params`: The params to use during indexing (see pipeline.run's params).
- `query_params`: The params to use during querying (see pipeline.run's params).
- `dataset`: The BEIR dataset to use.
- `dataset_dir`: The directory to store the dataset in.
- `top_k_values`: The top_k values each metric will be calculated for.
- `keep_index`: Whether to keep the index after evaluation.
If True the index will be kept after beir evaluation. Otherwise it will be deleted immediately afterwards.
Defaults to False.
Returns a tuple containing the ndcg, map, recall and precision scores.
Each metric is represented by a dictionary containing the scores for each top_k value.
<a id="base.Pipeline.eval"></a>
#### eval
@ -835,6 +864,34 @@ def __call__(*args, **kwargs)
Ray calls this method which is then re-directed to the corresponding component's run().
<a id="base._HaystackBeirRetrieverAdapter"></a>
## \_HaystackBeirRetrieverAdapter
```python
class _HaystackBeirRetrieverAdapter()
```
<a id="base._HaystackBeirRetrieverAdapter.__init__"></a>
#### \_\_init\_\_
```python
def __init__(index_pipeline: Pipeline, query_pipeline: Pipeline, index_params: dict, query_params: dict)
```
Adapter mimicking a BEIR retriever used by BEIR's EvaluateRetrieval class to run BEIR evaluations on Haystack Pipelines.
This has nothing to do with Haystack's retriever classes.
See https://github.com/beir-cellar/beir/blob/main/beir/retrieval/evaluation.py.
**Arguments**:
- `index_pipeline`: The indexing pipeline to use.
- `query_pipeline`: The query pipeline to evaluate.
- `index_params`: The params to use during indexing (see pipeline.run's params).
- `query_params`: The params to use during querying (see pipeline.run's params).
<a id="standard_pipelines"></a>
# Module standard\_pipelines

View File

@ -512,6 +512,16 @@ class BaseDocumentStore(BaseComponent):
):
pass
@abstractmethod
def delete_index(self, index: str):
    """
    Remove the index called `index` together with all data stored in it.

    This is an abstract method; it has no behavior of its own here.

    :param index: The name of the index to delete.
    :return: None
    """
@abstractmethod
def _create_document_field_map(self) -> Dict:
pass

View File

@ -485,3 +485,6 @@ class DeepsetCloudDocumentStore(KeywordDocumentStore):
headers: Optional[Dict[str, str]] = None,
):
raise NotImplementedError("DeepsetCloudDocumentStore currently does not support labels.")
def delete_index(self, index: str):
    """
    Index deletion is not supported by this document store.

    :param index: The name of the index to delete (unused).
    :raises NotImplementedError: Always.
    """
    unsupported_msg = "DeepsetCloudDocumentStore currently does not support deleting indexes."
    raise NotImplementedError(unsupported_msg)

View File

@ -318,12 +318,11 @@ class ElasticsearchDocumentStore(KeywordDocumentStore):
if self.search_fields:
for search_field in self.search_fields:
if search_field in mapping["properties"] and mapping["properties"][search_field]["type"] != "text":
host_data = self.client.transport.hosts[0]
raise Exception(
f"The search_field '{search_field}' of index '{index_name}' with type '{mapping['properties'][search_field]['type']}' "
f"does not have the right type 'text' to be queried in fulltext search. Please use only 'text' type properties as search_fields. "
f"This error might occur if you are trying to use haystack 1.0 and above with an existing elasticsearch index created with a previous version of haystack."
f"In this case deleting the index with `curl -X DELETE \"{host_data['host']}:{host_data['port']}/{index_name}\"` will fix your environment. "
f"does not have the right type 'text' to be queried in fulltext search. Please use only 'text' type properties as search_fields or use another index. "
f"This error might occur if you are trying to use haystack 1.0 and above with an existing elasticsearch index created with a previous version of haystack. "
f'In this case deleting the index with `delete_index(index="{index_name}")` will fix your environment. '
f"Note, that all data stored in the index will be lost!"
)
if self.embedding_field:
@ -1571,6 +1570,11 @@ class ElasticsearchDocumentStore(KeywordDocumentStore):
:param index: The name of the index to delete.
:return: None
"""
if index == self.index:
logger.warning(
f"Deletion of default index '{index}' detected. "
f"If you plan to use this index again, please reinstantiate '{self.__class__.__name__}' in order to avoid side-effects."
)
self.client.indices.delete(index=index, ignore=[400, 404])
logger.debug(f"deleted elasticsearch index {index}")
@ -1790,12 +1794,11 @@ class OpenSearchDocumentStore(ElasticsearchDocumentStore):
search_field in mappings["properties"]
and mappings["properties"][search_field]["type"] != "text"
):
host_data = self.client.transport.hosts[0]
raise Exception(
f"The search_field '{search_field}' of index '{index_name}' with type '{mappings['properties'][search_field]['type']}' "
f"does not have the right type 'text' to be queried in fulltext search. Please use only 'text' type properties as search_fields. "
f"This error might occur if you are trying to use haystack 1.0 and above with an existing elasticsearch index created with a previous version of haystack."
f"In this case deleting the index with `curl -X DELETE \"{host_data['host']}:{host_data['port']}/{index_name}\"` will fix your environment. "
f"does not have the right type 'text' to be queried in fulltext search. Please use only 'text' type properties as search_fields or use another index. "
f"This error might occur if you are trying to use haystack 1.0 and above with an existing elasticsearch index created with a previous version of haystack. "
f'In this case deleting the index with `delete_index(index="{index_name}")` will fix your environment. '
f"Note, that all data stored in the index will be lost!"
)

View File

@ -525,6 +525,21 @@ class FAISSDocumentStore(SQLDocumentStore):
super().delete_documents(index=index, ids=ids, filters=filters)
def delete_index(self, index: str):
    """
    Delete an existing index. The index including all data will be removed.

    The FAISS index registered under `index` (if any) is dropped from `self.faiss_indexes`,
    then the parent class's `delete_index` is called. An index without a FAISS entry no
    longer raises KeyError; the parent cleanup still runs.

    :param index: The name of the index to delete.
    :return: None
    """
    if index == self.index:
        logger.warning(
            f"Deletion of default index '{index}' detected. "
            f"If you plan to use this index again, please reinstantiate '{self.__class__.__name__}' in order to avoid side-effects."
        )
    # Guard the removal so that deleting an unknown index does not raise KeyError.
    if index in self.faiss_indexes:
        del self.faiss_indexes[index]
    super().delete_index(index)
def query_by_embedding(
self,
query_emb: np.ndarray,

View File

@ -743,6 +743,20 @@ class InMemoryDocumentStore(BaseDocumentStore):
for doc in docs_to_delete:
del self.indexes[index][doc.id]
def delete_index(self, index: str):
    """
    Delete an existing index. The index including all data will be removed.

    Deleting an index that is not present in `self.indexes` is a no-op instead of a KeyError.
    A warning is logged when the index being removed is this store's default index.

    :param index: The name of the index to delete.
    :return: None
    """
    if index not in self.indexes:
        # Nothing stored under this name, so there is nothing to remove.
        return
    if index == self.index:
        logger.warning(
            f"Deletion of default index '{index}' detected. "
            f"If you plan to use this index again, please reinstantiate '{self.__class__.__name__}' in order to avoid side-effects."
        )
    del self.indexes[index]
def delete_labels(
self,
index: Optional[str] = None,

View File

@ -483,6 +483,21 @@ class Milvus1DocumentStore(SQLDocumentStore):
# Delete from SQL at the end to allow the above .get_all_documents() to work properly
super().delete_documents(index=index, ids=ids, filters=filters)
def delete_index(self, index: str):
    """
    Drop the Milvus collection named `index`, then delegate to the parent class's `delete_index`.

    A warning is logged when `index` is this store's default index.

    :param index: The name of the index to delete.
    :return: None
    """
    removing_default_index = index == self.index
    if removing_default_index:
        logger.warning(
            f"Deletion of default index '{index}' detected. "
            f"If you plan to use this index again, please reinstantiate '{self.__class__.__name__}' in order to avoid side-effects."
        )

    # Collection first, parent cleanup second (same order as before).
    self.milvus_server.drop_collection(index)
    super().delete_index(index)
def get_all_documents_generator(
self,
index: Optional[str] = None,

View File

@ -474,6 +474,21 @@ class Milvus2DocumentStore(SQLDocumentStore):
index = index or self.index
super().delete_documents(index=index, filters=filters, ids=ids)
def delete_index(self, index: str):
    """
    Drop the Milvus collection named `index`, then delegate to the parent class's `delete_index`.

    A warning is logged when `index` is this store's default index.

    :param index: The name of the index to delete.
    :return: None
    """
    removing_default_index = index == self.index
    if removing_default_index:
        logger.warning(
            f"Deletion of default index '{index}' detected. "
            f"If you plan to use this index again, please reinstantiate '{self.__class__.__name__}' in order to avoid side-effects."
        )

    # Collection first, parent cleanup second (same order as before).
    utility.drop_collection(collection_name=index)
    super().delete_index(index)
def get_all_documents_generator(
self,
index: Optional[str] = None,

View File

@ -658,6 +658,15 @@ class SQLDocumentStore(BaseDocumentStore):
self.session.commit()
def delete_index(self, index: str):
    """
    Delete an index by delegating to `delete_documents(index)`.

    :param index: The name of the index to delete.
    :return: None
    """
    # The call's result is intentionally discarded; this method always returns None.
    self.delete_documents(index)
    return None
def delete_labels(
self,
index: Optional[str] = None,

View File

@ -1146,6 +1146,20 @@ class WeaviateDocumentStore(BaseDocumentStore):
for doc in docs_to_delete:
self.weaviate_client.data_object.delete(doc.id)
def delete_index(self, index: str):
    """
    Delete the Weaviate schema class named `index` through the client's `schema.delete_class`.

    A warning is logged when `index` is this store's default index.

    :param index: The name of the index to delete.
    :return: None
    """
    removing_default_index = index == self.index
    if removing_default_index:
        logger.warning(
            f"Deletion of default index '{index}' detected. "
            f"If you plan to use this index again, please reinstantiate '{self.__class__.__name__}' in order to avoid side-effects."
        )

    schema_api = self.weaviate_client.schema
    schema_api.delete_class(index)
def delete_labels(self):
"""
Implemented to respect BaseDocumentStore's contract.

View File

@ -1,6 +1,7 @@
from __future__ import annotations
from os import pipe
from typing import Dict, List, Optional, Any, Set
import tempfile
from typing import Dict, List, Optional, Any, Set, Tuple, Union
import copy
import json
@ -16,6 +17,7 @@ from jsonschema import Draft7Validator
from jsonschema.exceptions import ValidationError
from jsonschema import _utils as jsonschema_utils
from pandas.core.frame import DataFrame
from tqdm import tqdm
from transformers import pipelines
import yaml
from networkx import DiGraph
@ -45,7 +47,7 @@ except:
from haystack import __version__
from haystack.schema import EvaluationResult, MultiLabel, Document
from haystack.errors import PipelineError, PipelineConfigError
from haystack.errors import HaystackError, PipelineError, PipelineConfigError
from haystack.nodes.base import BaseComponent
from haystack.nodes.retriever.base import BaseRetriever
from haystack.document_stores.base import BaseDocumentStore
@ -576,7 +578,7 @@ class Pipeline(BasePipeline):
file_paths: Optional[List[str]] = None,
labels: Optional[MultiLabel] = None,
documents: Optional[List[Document]] = None,
meta: Optional[dict] = None,
meta: Optional[Union[dict, List[dict]]] = None,
params: Optional[dict] = None,
debug: Optional[bool] = None,
):
@ -691,6 +693,83 @@ class Pipeline(BasePipeline):
i += 1 # attempt executing next node in the queue as current `node_id` has unprocessed predecessors
return node_output
@classmethod
def eval_beir(
    cls,
    index_pipeline: Pipeline,
    query_pipeline: Pipeline,
    index_params: Optional[dict] = None,
    query_params: Optional[dict] = None,
    dataset: str = "scifact",
    dataset_dir: Path = Path("."),
    top_k_values: Optional[List[int]] = None,
    keep_index: bool = False,
) -> Tuple[Dict[str, float], Dict[str, float], Dict[str, float], Dict[str, float]]:
    """
    Runs information retrieval evaluation of a pipeline using BEIR on a specified BEIR dataset.
    See https://github.com/beir-cellar/beir for more information.

    :param index_pipeline: The indexing pipeline to use.
    :param query_pipeline: The query pipeline to evaluate.
    :param index_params: The params to use during indexing (see pipeline.run's params).
                         Defaults to an empty dict.
    :param query_params: The params to use during querying (see pipeline.run's params).
                         Defaults to an empty dict.
    :param dataset: The BEIR dataset to use.
    :param dataset_dir: The directory to store the dataset in.
    :param top_k_values: The top_k values each metric will be calculated for.
                         Defaults to [1, 3, 5, 10, 100, 1000].
    :param keep_index: Whether to keep the index after evaluation.
                       If True the index will be kept after beir evaluation. Otherwise it will be deleted
                       immediately afterwards, even if retrieval fails.
                       Defaults to False.
    :raises HaystackError: If beir is not installed or the index pipeline's document store is not empty.

    Returns a tuple containing the ndcg, map, recall and precision scores.
    Each metric is represented by a dictionary containing the scores for each top_k value.
    """
    try:
        from beir import util
        from beir.datasets.data_loader import GenericDataLoader
        from beir.retrieval.evaluation import EvaluateRetrieval
    except ModuleNotFoundError as e:
        raise HaystackError("beir is not installed. Please run `pip install farm-haystack[beir]`...") from e

    # Build fresh containers per call instead of sharing mutable default arguments across calls.
    if index_params is None:
        index_params = {}
    if query_params is None:
        query_params = {}
    if top_k_values is None:
        top_k_values = [1, 3, 5, 10, 100, 1000]

    url = f"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{dataset}.zip"
    data_path = util.download_and_unzip(url, dataset_dir)
    logger.info(f"Dataset downloaded here: {data_path}")
    corpus, queries, qrels = GenericDataLoader(data_path).load(split="test")  # or split = "train" or "dev"

    # check index before eval
    document_store = index_pipeline.get_document_store()
    if document_store is not None:
        if document_store.get_document_count() > 0:
            raise HaystackError(f"Index '{document_store.index}' is not empty. Please provide an empty index.")

        if hasattr(document_store, "search_fields"):
            search_fields = getattr(document_store, "search_fields")
            if "name" not in search_fields:
                logger.warning(
                    "Field 'name' is not part of your DocumentStore's search_fields. Titles won't be searchable. "
                    "Please set search_fields appropriately."
                )

    haystack_retriever = _HaystackBeirRetrieverAdapter(
        index_pipeline=index_pipeline,
        query_pipeline=query_pipeline,
        index_params=index_params,
        query_params=query_params,
    )
    retriever = EvaluateRetrieval(haystack_retriever, k_values=top_k_values)

    try:
        # Retrieve results (format of results is identical to qrels)
        results = retriever.retrieve(corpus, queries)
    finally:
        # Clean up document store. Done in `finally` so a failed run does not leave a
        # half-filled index behind that would trip the "not empty" check on the next attempt.
        # The index was verified to be empty above, so only evaluation data is removed.
        if not keep_index and document_store is not None and document_store.index is not None:
            logger.info(f"Cleaning up: deleting index '{document_store.index}'...")
            document_store.delete_index(document_store.index)

    # Evaluate your retrieval using NDCG@k, MAP@K ...
    logger.info(f"Retriever evaluation for k in: {retriever.k_values}")
    ndcg, map_, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
    return ndcg, map_, recall, precision
@send_event
def eval(
self,
@ -1561,3 +1640,51 @@ class _RayDeploymentWrapper:
Ray calls this method which is then re-directed to the corresponding component's run().
"""
return self.node._dispatch_run(*args, **kwargs)
class _HaystackBeirRetrieverAdapter:
    """
    Exposes a pair of Haystack pipelines through the `search()` interface that
    BEIR's EvaluateRetrieval calls on its retriever.
    """

    def __init__(self, index_pipeline: Pipeline, query_pipeline: Pipeline, index_params: dict, query_params: dict):
        """
        Adapter mimicking a BEIR retriever used by BEIR's EvaluateRetrieval class to run BEIR evaluations on Haystack Pipelines.
        This has nothing to do with Haystack's retriever classes.
        See https://github.com/beir-cellar/beir/blob/main/beir/retrieval/evaluation.py.

        :param index_pipeline: The indexing pipeline to use.
        :param query_pipeline: The query pipeline to evaluate.
        :param index_params: The params to use during indexing (see pipeline.run's params).
        :param query_params: The params to use during querying (see pipeline.run's params).
        """
        self.index_pipeline = index_pipeline
        self.query_pipeline = query_pipeline
        self.index_params = index_params
        self.query_params = query_params

    def search(
        self, corpus: Dict[str, Dict[str, str]], queries: Dict[str, str], top_k: int, score_function: str, **kwargs
    ) -> Dict[str, Dict[str, float]]:
        """
        Index the whole corpus with the index pipeline, then run every query through the query pipeline.

        Each corpus entry is written to its own file in a temporary directory and passed to the
        index pipeline together with a meta dict holding the BEIR document id ("id") and title ("name").

        :param corpus: Mapping of document id to a dict with a "text" entry and an optional "title" entry.
        :param queries: Mapping of query id to query string.
        :param top_k: Set as the "top_k" entry of the params passed to the query pipeline.
        :param score_function: Accepted for interface compatibility; not used.
        :param kwargs: Accepted for interface compatibility; not used.
        :return: Mapping of query id to a mapping of document id (taken from `doc.meta["id"]`) to score.
        """
        with tempfile.TemporaryDirectory() as temp_dir:
            file_paths = []
            metas = []
            for doc_id, doc in corpus.items():
                file_path = f"{temp_dir}/{doc_id}"
                # Explicit encoding: the platform default may be unable to encode the corpus text.
                # NOTE(review): assumes the index pipeline's converter reads the files as UTF-8 - confirm.
                with open(file_path, "w", encoding="utf-8") as f:
                    f.write(doc["text"])
                file_paths.append(file_path)
                metas.append({"id": doc_id, "name": doc.get("title", None)})

            logger.info(f"indexing {len(corpus)} documents...")
            self.index_pipeline.run(file_paths=file_paths, meta=metas, params=self.index_params)
            logger.info("indexing finished.")

            # adjust query_params to ensure top_k is retrieved
            # (deep copy so the caller's query_params dict is left untouched)
            query_params = copy.deepcopy(self.query_params)
            query_params["top_k"] = top_k

            results = {}
            for q_id, query in tqdm(queries.items(), total=len(queries)):
                res = self.query_pipeline.run(query=query, params=query_params)
                docs = res["documents"]
                query_results = {doc.meta["id"]: doc.score for doc in docs}
                results[q_id] = query_results

            return results

View File

@ -174,6 +174,8 @@ ray =
aiorwlock>=1.3.0,<2
colab =
grpcio==1.43.0
beir =
beir
dev =
# Type check
mypy
@ -197,9 +199,9 @@ dev =
test =
farm-haystack[docstores,crawler,preprocessing,ocr,ray,dev]
all =
farm-haystack[docstores,crawler,preprocessing,ocr,ray,dev,onnx]
farm-haystack[docstores,crawler,preprocessing,ocr,ray,dev,onnx,beir]
all-gpu =
farm-haystack[docstores-gpu,crawler,preprocessing,ocr,ray,dev,onnx-gpu]
farm-haystack[docstores-gpu,crawler,preprocessing,ocr,ray,dev,onnx-gpu,beir]
[tool:pytest]

View File

@ -216,6 +216,9 @@ class MockDocumentStore(BaseDocumentStore):
def write_labels(self, *a, **k):
    """Test stub: accepts any arguments and does nothing."""
    return None
def delete_index(self, *a, **k):
    """Test stub: accepts any arguments and does nothing."""
    return None
class MockRetriever(BaseRetriever):
outgoing_edges = 1