From ca86cc834d80745d924cf4be6ddaca7562323ddc Mon Sep 17 00:00:00 2001 From: tstadel <60758086+tstadel@users.noreply.github.com> Date: Mon, 21 Mar 2022 19:04:28 +0100 Subject: [PATCH] Integrate BEIR (#2333) * introduce eval_beir() to Pipeline * add beir dependency * Update Documentation & Code Style * top_k_values added + refactoring * Update Documentation & Code Style * enable titles during beir eval * Update Documentation & Code Style * raise HaystackError instead of PipelineError * get rid of forced dedicated index * minor docstring and comment fixes * show warning on default index deletion * Update Documentation & Code Style * add delete_index to MockDocumentStore Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- docs/_src/api/api/document_store.md | 127 +++++++++++++++++++++ docs/_src/api/api/pipelines.md | 59 +++++++++- haystack/document_stores/base.py | 10 ++ haystack/document_stores/deepsetcloud.py | 3 + haystack/document_stores/elasticsearch.py | 19 ++-- haystack/document_stores/faiss.py | 15 +++ haystack/document_stores/memory.py | 14 +++ haystack/document_stores/milvus1.py | 15 +++ haystack/document_stores/milvus2.py | 15 +++ haystack/document_stores/sql.py | 9 ++ haystack/document_stores/weaviate.py | 14 +++ haystack/pipelines/base.py | 133 +++++++++++++++++++++- setup.cfg | 6 +- test/conftest.py | 3 + 14 files changed, 428 insertions(+), 14 deletions(-) diff --git a/docs/_src/api/api/document_store.md b/docs/_src/api/api/document_store.md index c5422c956..85f7895d6 100644 --- a/docs/_src/api/api/document_store.md +++ b/docs/_src/api/api/document_store.md @@ -250,6 +250,25 @@ When set to None (default) all available eval documents are used. same question might be found in different contexts. - `headers`: Custom HTTP headers to pass to document store client if supported (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='} for basic authentication) + + +#### delete\_index + +```python +@abstractmethod +def delete_index(index: str) +``` + +Delete an existing index. The index including all data will be removed. + +**Arguments**: + +- `index`: The name of the index to delete. + +**Returns**: + +None + #### run @@ -1842,6 +1861,24 @@ Example: None + + +#### delete\_index + +```python +def delete_index(index: str) +``` + +Delete an existing index. The index including all data will be removed. + +**Arguments**: + +- `index`: The name of the index to delete. + +**Returns**: + +None + #### delete\_labels @@ -2127,6 +2164,24 @@ have their ID in the list). None + + +#### delete\_index + +```python +def delete_index(index: str) +``` + +Delete an existing index. The index including all data will be removed. + +**Arguments**: + +- `index`: The name of the index to delete. + +**Returns**: + +None + #### delete\_labels @@ -2371,6 +2426,24 @@ have their ID in the list). None + + +#### delete\_index + +```python +def delete_index(index: str) +``` + +Delete an existing index. The index including all data will be removed. + +**Arguments**: + +- `index`: The name of the index to delete. + +**Returns**: + +None + #### query\_by\_embedding @@ -2641,6 +2714,24 @@ have their ID in the list). None + + +#### delete\_index + +```python +def delete_index(index: str) +``` + +Delete an existing index. The index including all data will be removed. + +**Arguments**: + +- `index`: The name of the index to delete. + +**Returns**: + +None + #### get\_all\_documents\_generator @@ -2932,6 +3023,24 @@ Example: {"name": ["some", "more"], "category": ["only_one"]} None + + +#### delete\_index + +```python +def delete_index(index: str) +``` + +Delete an existing index. The index including all data will be removed. + +**Arguments**: + +- `index`: The name of the index to delete. + +**Returns**: + +None + #### get\_all\_documents\_generator @@ -3565,6 +3674,24 @@ operation. None + + +#### delete\_index + +```python +def delete_index(index: str) +``` + +Delete an existing index. The index including all data will be removed. + +**Arguments**: + +- `index`: The name of the index to delete. + +**Returns**: + +None + #### delete\_labels diff --git a/docs/_src/api/api/pipelines.md b/docs/_src/api/api/pipelines.md index 700eca06f..0cce6fc5d 100644 --- a/docs/_src/api/api/pipelines.md +++ b/docs/_src/api/api/pipelines.md @@ -413,7 +413,7 @@ Set the component for a node in the Pipeline. #### run ```python -def run(query: Optional[str] = None, file_paths: Optional[List[str]] = None, labels: Optional[MultiLabel] = None, documents: Optional[List[Document]] = None, meta: Optional[dict] = None, params: Optional[dict] = None, debug: Optional[bool] = None) +def run(query: Optional[str] = None, file_paths: Optional[List[str]] = None, labels: Optional[MultiLabel] = None, documents: Optional[List[Document]] = None, meta: Optional[Union[dict, List[dict]]] = None, params: Optional[dict] = None, debug: Optional[bool] = None) ``` Runs the pipeline, one node at a time. @@ -434,6 +434,35 @@ about their execution. By default these include the input parameters they received and the output they generated. All debug information can then be found in the dict returned by this method under the key "_debug" + + +#### eval\_beir + +```python +@classmethod +def eval_beir(cls, index_pipeline: Pipeline, query_pipeline: Pipeline, index_params: dict = {}, query_params: dict = {}, dataset: str = "scifact", dataset_dir: Path = Path("."), top_k_values: List[int] = [1, 3, 5, 10, 100, 1000], keep_index: bool = False) -> Tuple[Dict[str, float], Dict[str, float], Dict[str, float], Dict[str, float]] +``` + +Runs information retrieval evaluation of a pipeline using BEIR on a specified BEIR dataset. + +See https://github.com/beir-cellar/beir for more information. + +**Arguments**: + +- `index_pipeline`: The indexing pipeline to use. +- `query_pipeline`: The query pipeline to evaluate. +- `index_params`: The params to use during indexing (see pipeline.run's params). +- `query_params`: The params to use during querying (see pipeline.run's params). +- `dataset`: The BEIR dataset to use. +- `dataset_dir`: The directory to store the dataset to. +- `top_k_values`: The top_k values each metric will be calculated for. +- `keep_index`: Whether to keep the index after evaluation. +If True the index will be kept after beir evaluation. Otherwise it will be deleted immediately afterwards. + Defaults to False. + +Returns a tuple containing the ncdg, map, recall and precision scores. +Each metric is represented by a dictionary containing the scores for each top_k value. + #### eval @@ -835,6 +864,34 @@ def __call__(*args, **kwargs) Ray calls this method which is then re-directed to the corresponding component's run(). + + +## \_HaystackBeirRetrieverAdapter + +```python +class _HaystackBeirRetrieverAdapter() +``` + + + +#### \_\_init\_\_ + +```python +def __init__(index_pipeline: Pipeline, query_pipeline: Pipeline, index_params: dict, query_params: dict) +``` + +Adapter mimicking a BEIR retriever used by BEIR's EvaluateRetrieval class to run BEIR evaluations on Haystack Pipelines. + +This has nothing to do with Haystack's retriever classes. +See https://github.com/beir-cellar/beir/blob/main/beir/retrieval/evaluation.py. + +**Arguments**: + +- `index_pipeline`: The indexing pipeline to use. +- `query_pipeline`: The query pipeline to evaluate. +- `index_params`: The params to use during indexing (see pipeline.run's params). +- `query_params`: The params to use during querying (see pipeline.run's params). + # Module standard\_pipelines diff --git a/haystack/document_stores/base.py b/haystack/document_stores/base.py index e8e7d2f97..6b6f810e5 100644 --- a/haystack/document_stores/base.py +++ b/haystack/document_stores/base.py @@ -512,6 +512,16 @@ class BaseDocumentStore(BaseComponent): ): pass + @abstractmethod + def delete_index(self, index: str): + """ + Delete an existing index. The index including all data will be removed. + + :param index: The name of the index to delete. + :return: None + """ + pass + @abstractmethod def _create_document_field_map(self) -> Dict: pass diff --git a/haystack/document_stores/deepsetcloud.py b/haystack/document_stores/deepsetcloud.py index fe4506691..5fdce4f28 100644 --- a/haystack/document_stores/deepsetcloud.py +++ b/haystack/document_stores/deepsetcloud.py @@ -485,3 +485,6 @@ class DeepsetCloudDocumentStore(KeywordDocumentStore): headers: Optional[Dict[str, str]] = None, ): raise NotImplementedError("DeepsetCloudDocumentStore currently does not support labels.") + + def delete_index(self, index: str): + raise NotImplementedError("DeepsetCloudDocumentStore currently does not support deleting indexes.") diff --git a/haystack/document_stores/elasticsearch.py b/haystack/document_stores/elasticsearch.py index 0eeb6f611..bd4dc327f 100644 --- a/haystack/document_stores/elasticsearch.py +++ b/haystack/document_stores/elasticsearch.py @@ -318,12 +318,11 @@ class ElasticsearchDocumentStore(KeywordDocumentStore): if self.search_fields: for search_field in self.search_fields: if search_field in mapping["properties"] and mapping["properties"][search_field]["type"] != "text": - host_data = self.client.transport.hosts[0] raise Exception( f"The search_field '{search_field}' of index '{index_name}' with type '{mapping['properties'][search_field]['type']}' " - f"does not have the right type 'text' to be queried in fulltext search. Please use only 'text' type properties as search_fields. " - f"This error might occur if you are trying to use haystack 1.0 and above with an existing elasticsearch index created with a previous version of haystack." - f"In this case deleting the index with `curl -X DELETE \"{host_data['host']}:{host_data['port']}/{index_name}\"` will fix your environment. " + f"does not have the right type 'text' to be queried in fulltext search. Please use only 'text' type properties as search_fields or use another index. " + f"This error might occur if you are trying to use haystack 1.0 and above with an existing elasticsearch index created with a previous version of haystack. " + f'In this case deleting the index with `delete_index(index="{index_name}")` will fix your environment. ' f"Note, that all data stored in the index will be lost!" ) if self.embedding_field: @@ -1571,6 +1570,11 @@ class ElasticsearchDocumentStore(KeywordDocumentStore): :param index: The name of the index to delete. :return: None """ + if index == self.index: + logger.warning( + f"Deletion of default index '{index}' detected. " + f"If you plan to use this index again, please reinstantiate '{self.__class__.__name__}' in order to avoid side-effects." + ) self.client.indices.delete(index=index, ignore=[400, 404]) logger.debug(f"deleted elasticsearch index {index}") @@ -1790,12 +1794,11 @@ class OpenSearchDocumentStore(ElasticsearchDocumentStore): search_field in mappings["properties"] and mappings["properties"][search_field]["type"] != "text" ): - host_data = self.client.transport.hosts[0] raise Exception( f"The search_field '{search_field}' of index '{index_name}' with type '{mappings['properties'][search_field]['type']}' " - f"does not have the right type 'text' to be queried in fulltext search. Please use only 'text' type properties as search_fields. " - f"This error might occur if you are trying to use haystack 1.0 and above with an existing elasticsearch index created with a previous version of haystack." - f"In this case deleting the index with `curl -X DELETE \"{host_data['host']}:{host_data['port']}/{index_name}\"` will fix your environment. " + f"does not have the right type 'text' to be queried in fulltext search. Please use only 'text' type properties as search_fields or use another index. " + f"This error might occur if you are trying to use haystack 1.0 and above with an existing elasticsearch index created with a previous version of haystack. " + f'In this case deleting the index with `delete_index(index="{index_name}")` will fix your environment. ' f"Note, that all data stored in the index will be lost!" ) diff --git a/haystack/document_stores/faiss.py b/haystack/document_stores/faiss.py index 0d62446a1..0987c4437 100644 --- a/haystack/document_stores/faiss.py +++ b/haystack/document_stores/faiss.py @@ -525,6 +525,21 @@ class FAISSDocumentStore(SQLDocumentStore): super().delete_documents(index=index, ids=ids, filters=filters) + def delete_index(self, index: str): + """ + Delete an existing index. The index including all data will be removed. + + :param index: The name of the index to delete. + :return: None + """ + if index == self.index: + logger.warning( + f"Deletion of default index '{index}' detected. " + f"If you plan to use this index again, please reinstantiate '{self.__class__.__name__}' in order to avoid side-effects." + ) + del self.faiss_indexes[index] + super().delete_index(index) + def query_by_embedding( self, query_emb: np.ndarray, diff --git a/haystack/document_stores/memory.py b/haystack/document_stores/memory.py index f877f0e7b..6f8403cee 100644 --- a/haystack/document_stores/memory.py +++ b/haystack/document_stores/memory.py @@ -743,6 +743,20 @@ class InMemoryDocumentStore(BaseDocumentStore): for doc in docs_to_delete: del self.indexes[index][doc.id] + def delete_index(self, index: str): + """ + Delete an existing index. The index including all data will be removed. + + :param index: The name of the index to delete. + :return: None + """ + if index == self.index: + logger.warning( + f"Deletion of default index '{index}' detected. " + f"If you plan to use this index again, please reinstantiate '{self.__class__.__name__}' in order to avoid side-effects." + ) + del self.indexes[index] + def delete_labels( self, index: Optional[str] = None, diff --git a/haystack/document_stores/milvus1.py b/haystack/document_stores/milvus1.py index fdeb648d2..e5daccac1 100644 --- a/haystack/document_stores/milvus1.py +++ b/haystack/document_stores/milvus1.py @@ -483,6 +483,21 @@ class Milvus1DocumentStore(SQLDocumentStore): # Delete from SQL at the end to allow the above .get_all_documents() to work properly super().delete_documents(index=index, ids=ids, filters=filters) + def delete_index(self, index: str): + """ + Delete an existing index. The index including all data will be removed. + + :param index: The name of the index to delete. + :return: None + """ + if index == self.index: + logger.warning( + f"Deletion of default index '{index}' detected. " + f"If you plan to use this index again, please reinstantiate '{self.__class__.__name__}' in order to avoid side-effects." + ) + self.milvus_server.drop_collection(index) + super().delete_index(index) + def get_all_documents_generator( self, index: Optional[str] = None, diff --git a/haystack/document_stores/milvus2.py b/haystack/document_stores/milvus2.py index 3120e92d1..b6f86dd2d 100644 --- a/haystack/document_stores/milvus2.py +++ b/haystack/document_stores/milvus2.py @@ -474,6 +474,21 @@ class Milvus2DocumentStore(SQLDocumentStore): index = index or self.index super().delete_documents(index=index, filters=filters, ids=ids) + def delete_index(self, index: str): + """ + Delete an existing index. The index including all data will be removed. + + :param index: The name of the index to delete. + :return: None + """ + if index == self.index: + logger.warning( + f"Deletion of default index '{index}' detected. " + f"If you plan to use this index again, please reinstantiate '{self.__class__.__name__}' in order to avoid side-effects." + ) + utility.drop_collection(collection_name=index) + super().delete_index(index) + def get_all_documents_generator( self, index: Optional[str] = None, diff --git a/haystack/document_stores/sql.py b/haystack/document_stores/sql.py index 0cca02ea0..9c71b0127 100644 --- a/haystack/document_stores/sql.py +++ b/haystack/document_stores/sql.py @@ -658,6 +658,15 @@ class SQLDocumentStore(BaseDocumentStore): self.session.commit() + def delete_index(self, index: str): + """ + Delete an existing index. The index including all data will be removed. + + :param index: The name of the index to delete. + :return: None + """ + self.delete_documents(index) + def delete_labels( self, index: Optional[str] = None, diff --git a/haystack/document_stores/weaviate.py b/haystack/document_stores/weaviate.py index d096399f6..cfdb91c00 100644 --- a/haystack/document_stores/weaviate.py +++ b/haystack/document_stores/weaviate.py @@ -1146,6 +1146,20 @@ class WeaviateDocumentStore(BaseDocumentStore): for doc in docs_to_delete: self.weaviate_client.data_object.delete(doc.id) + def delete_index(self, index: str): + """ + Delete an existing index. The index including all data will be removed. + + :param index: The name of the index to delete. + :return: None + """ + if index == self.index: + logger.warning( + f"Deletion of default index '{index}' detected. " + f"If you plan to use this index again, please reinstantiate '{self.__class__.__name__}' in order to avoid side-effects." + ) + self.weaviate_client.schema.delete_class(index) + def delete_labels(self): """ Implemented to respect BaseDocumentStore's contract. diff --git a/haystack/pipelines/base.py b/haystack/pipelines/base.py index c45655f45..35cd826d9 100644 --- a/haystack/pipelines/base.py +++ b/haystack/pipelines/base.py @@ -1,6 +1,7 @@ from __future__ import annotations from os import pipe -from typing import Dict, List, Optional, Any, Set +import tempfile +from typing import Dict, List, Optional, Any, Set, Tuple, Union import copy import json @@ -16,6 +17,7 @@ from jsonschema import Draft7Validator from jsonschema.exceptions import ValidationError from jsonschema import _utils as jsonschema_utils from pandas.core.frame import DataFrame +from tqdm import tqdm from transformers import pipelines import yaml from networkx import DiGraph @@ -45,7 +47,7 @@ except: from haystack import __version__ from haystack.schema import EvaluationResult, MultiLabel, Document -from haystack.errors import PipelineError, PipelineConfigError +from haystack.errors import HaystackError, PipelineError, PipelineConfigError from haystack.nodes.base import BaseComponent from haystack.nodes.retriever.base import BaseRetriever from haystack.document_stores.base import BaseDocumentStore @@ -576,7 +578,7 @@ class Pipeline(BasePipeline): file_paths: Optional[List[str]] = None, labels: Optional[MultiLabel] = None, documents: Optional[List[Document]] = None, - meta: Optional[dict] = None, + meta: Optional[Union[dict, List[dict]]] = None, params: Optional[dict] = None, debug: Optional[bool] = None, ): @@ -691,6 +693,83 @@ class Pipeline(BasePipeline): i += 1 # attempt executing next node in the queue as current `node_id` has unprocessed predecessors return node_output + @classmethod + def eval_beir( + cls, + index_pipeline: Pipeline, + query_pipeline: Pipeline, + index_params: dict = {}, + query_params: dict = {}, + dataset: str = "scifact", + dataset_dir: Path = Path("."), + top_k_values: List[int] = [1, 3, 5, 10, 100, 1000], + keep_index: bool = False, + ) -> Tuple[Dict[str, float], Dict[str, float], Dict[str, float], Dict[str, float]]: + """ + Runs information retrieval evaluation of a pipeline using BEIR on a specified BEIR dataset. + See https://github.com/beir-cellar/beir for more information. + + :param index_pipeline: The indexing pipeline to use. + :param query_pipeline: The query pipeline to evaluate. + :param index_params: The params to use during indexing (see pipeline.run's params). + :param query_params: The params to use during querying (see pipeline.run's params). + :param dataset: The BEIR dataset to use. + :param dataset_dir: The directory to store the dataset to. + :param top_k_values: The top_k values each metric will be calculated for. + :param keep_index: Whether to keep the index after evaluation. + If True the index will be kept after beir evaluation. Otherwise it will be deleted immediately afterwards. + Defaults to False. + + Returns a tuple containing the ncdg, map, recall and precision scores. + Each metric is represented by a dictionary containing the scores for each top_k value. + """ + try: + from beir import util + from beir.datasets.data_loader import GenericDataLoader + from beir.retrieval.evaluation import EvaluateRetrieval + except ModuleNotFoundError as e: + raise HaystackError("beir is not installed. Please run `pip install farm-haystack[beir]`...") from e + + url = f"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{dataset}.zip" + data_path = util.download_and_unzip(url, dataset_dir) + logger.info(f"Dataset downloaded here: {data_path}") + corpus, queries, qrels = GenericDataLoader(data_path).load(split="test") # or split = "train" or "dev" + + # check index before eval + document_store = index_pipeline.get_document_store() + if document_store is not None: + if document_store.get_document_count() > 0: + raise HaystackError(f"Index '{document_store.index}' is not empty. Please provide an empty index.") + + if hasattr(document_store, "search_fields"): + search_fields = getattr(document_store, "search_fields") + if "name" not in search_fields: + logger.warning( + "Field 'name' is not part of your DocumentStore's search_fields. Titles won't be searchable. " + "Please set search_fields appropriately." + ) + + haystack_retriever = _HaystackBeirRetrieverAdapter( + index_pipeline=index_pipeline, + query_pipeline=query_pipeline, + index_params=index_params, + query_params=query_params, + ) + retriever = EvaluateRetrieval(haystack_retriever, k_values=top_k_values) + + # Retrieve results (format of results is identical to qrels) + results = retriever.retrieve(corpus, queries) + + # Clean up document store + if not keep_index and document_store is not None and document_store.index is not None: + logger.info(f"Cleaning up: deleting index '{document_store.index}'...") + document_store.delete_index(document_store.index) + + # Evaluate your retrieval using NDCG@k, MAP@K ... + logger.info(f"Retriever evaluation for k in: {retriever.k_values}") + ndcg, map_, recall, precision = retriever.evaluate(qrels, results, retriever.k_values) + return ndcg, map_, recall, precision + @send_event def eval( self, @@ -1561,3 +1640,51 @@ class _RayDeploymentWrapper: Ray calls this method which is then re-directed to the corresponding component's run(). """ return self.node._dispatch_run(*args, **kwargs) + + +class _HaystackBeirRetrieverAdapter: + def __init__(self, index_pipeline: Pipeline, query_pipeline: Pipeline, index_params: dict, query_params: dict): + """ + Adapter mimicking a BEIR retriever used by BEIR's EvaluateRetrieval class to run BEIR evaluations on Haystack Pipelines. + This has nothing to do with Haystack's retriever classes. + See https://github.com/beir-cellar/beir/blob/main/beir/retrieval/evaluation.py. + + :param index_pipeline: The indexing pipeline to use. + :param query_pipeline: The query pipeline to evaluate. + :param index_params: The params to use during indexing (see pipeline.run's params). + :param query_params: The params to use during querying (see pipeline.run's params). + """ + self.index_pipeline = index_pipeline + self.query_pipeline = query_pipeline + self.index_params = index_params + self.query_params = query_params + + def search( + self, corpus: Dict[str, Dict[str, str]], queries: Dict[str, str], top_k: int, score_function: str, **kwargs + ) -> Dict[str, Dict[str, float]]: + with tempfile.TemporaryDirectory() as temp_dir: + file_paths = [] + metas = [] + for id, doc in corpus.items(): + file_path = f"{temp_dir}/{id}" + with open(file_path, "w") as f: + f.write(doc["text"]) + file_paths.append(file_path) + metas.append({"id": id, "name": doc.get("title", None)}) + + logger.info(f"indexing {len(corpus)} documents...") + self.index_pipeline.run(file_paths=file_paths, meta=metas, params=self.index_params) + logger.info(f"indexing finished.") + + # adjust query_params to ensure top_k is retrieved + query_params = copy.deepcopy(self.query_params) + query_params["top_k"] = top_k + + results = {} + for q_id, query in tqdm(queries.items(), total=len(queries)): + res = self.query_pipeline.run(query=query, params=query_params) + docs = res["documents"] + query_results = {doc.meta["id"]: doc.score for doc in docs} + results[q_id] = query_results + + return results diff --git a/setup.cfg b/setup.cfg index 3fcc41fdb..9a8bb6b8c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -174,6 +174,8 @@ ray = aiorwlock>=1.3.0,<2 colab = grpcio==1.43.0 +beir = + beir dev = # Type check mypy @@ -197,9 +199,9 @@ dev = test = farm-haystack[docstores,crawler,preprocessing,ocr,ray,dev] all = - farm-haystack[docstores,crawler,preprocessing,ocr,ray,dev,onnx] + farm-haystack[docstores,crawler,preprocessing,ocr,ray,dev,onnx,beir] all-gpu = - farm-haystack[docstores-gpu,crawler,preprocessing,ocr,ray,dev,onnx-gpu] + farm-haystack[docstores-gpu,crawler,preprocessing,ocr,ray,dev,onnx-gpu,beir] [tool:pytest] diff --git a/test/conftest.py b/test/conftest.py index 24e3039e2..ec9980f29 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -216,6 +216,9 @@ class MockDocumentStore(BaseDocumentStore): def write_labels(self, *a, **k): pass + def delete_index(self, *a, **k): + pass + class MockRetriever(BaseRetriever): outgoing_edges = 1