mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-02 18:59:28 +00:00
feat: Adds all_terms_must_match parameter to BM25Retriever at runtime (#3627)
* Adds all_terms_must_match implementation and tests * Adds all_terms_must_match as Optional Signed-off-by: Unai Garay <unaigaraymaestre@gmail.com> * Avoid mypy error and follow pattern checking var is None * Mypy works ok on this file now * added mypy ignores to BaseRetriever * ignoring all overrides for this file * Updates sparse retriever `all_terms_must_match` docstring Co-authored-by: Agnieszka Marzec <97166305+agnieszka-m@users.noreply.github.com> * Updates sparse retriever `all_terms_must_match` docstring Co-authored-by: Agnieszka Marzec <97166305+agnieszka-m@users.noreply.github.com> * Updates sparse retriever `all_terms_must_match` docstring Co-authored-by: Agnieszka Marzec <97166305+agnieszka-m@users.noreply.github.com> * Updates sparse retrieve_batch `all_terms_must_match` docstring Co-authored-by: Agnieszka Marzec <97166305+agnieszka-m@users.noreply.github.com> * Updates sparse retrieve_batch `all_terms_must_match` docstring Co-authored-by: Agnieszka Marzec <97166305+agnieszka-m@users.noreply.github.com> * Updates sparse retrieve_batch `all_terms_must_match` docstring Co-authored-by: Agnieszka Marzec <97166305+agnieszka-m@users.noreply.github.com> * marked elasticsearch Signed-off-by: Unai Garay <unaigaraymaestre@gmail.com> Co-authored-by: Mayank Jobanputra <mayankjobanputra@gmail.com> Co-authored-by: Agnieszka Marzec <97166305+agnieszka-m@users.noreply.github.com>
This commit is contained in:
parent
c1c1c97bb2
commit
77cea8b140
@ -1,3 +1,4 @@
|
||||
# mypy: disable-error-code=override
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import logging
|
||||
@ -118,6 +119,7 @@ class BM25Retriever(BaseRetriever):
|
||||
query: str,
|
||||
filters: Optional[Dict[str, Union[Dict, List, str, int, float, bool]]] = None,
|
||||
top_k: Optional[int] = None,
|
||||
all_terms_must_match: Optional[bool] = None,
|
||||
index: Optional[str] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
scale_score: Optional[bool] = None,
|
||||
@ -194,6 +196,10 @@ class BM25Retriever(BaseRetriever):
|
||||
}
|
||||
```
|
||||
:param top_k: How many documents to return per query.
|
||||
:param all_terms_must_match: Whether all terms of the query must match the document.
|
||||
When set to `True`, the Retriever returns only documents that contain all query terms (that means the AND operator is being used implicitly between query terms. For example, the query "cozy fish restaurant" is read as "cozy AND fish AND restaurant").
|
||||
When set to `False`, the Retriever returns documents containing at least one query term (this means the OR operator is being used implicitly between query terms. For example, the query "cozy fish restaurant" is read as "cozy OR fish OR restaurant").
|
||||
Defaults to `None`. If you set a value for this parameter, it overrides `self.all_terms_must_match` for this call; the value set on the Retriever itself is left unchanged.
|
||||
:param index: The name of the index in the DocumentStore from which to retrieve documents
|
||||
:param headers: Custom HTTP headers to pass to elasticsearch client (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='})
|
||||
Check out https://www.elastic.co/guide/en/elasticsearch/reference/current/http-clients.html for more information.
|
||||
@ -216,12 +222,14 @@ class BM25Retriever(BaseRetriever):
|
||||
index = document_store.index
|
||||
if scale_score is None:
|
||||
scale_score = self.scale_score
|
||||
if all_terms_must_match is None:
|
||||
all_terms_must_match = self.all_terms_must_match
|
||||
|
||||
documents = document_store.query(
|
||||
query=query,
|
||||
filters=filters,
|
||||
top_k=top_k,
|
||||
all_terms_must_match=self.all_terms_must_match,
|
||||
all_terms_must_match=all_terms_must_match,
|
||||
custom_query=self.custom_query,
|
||||
index=index,
|
||||
headers=headers,
|
||||
@ -239,6 +247,7 @@ class BM25Retriever(BaseRetriever):
|
||||
]
|
||||
] = None,
|
||||
top_k: Optional[int] = None,
|
||||
all_terms_must_match: Optional[bool] = None,
|
||||
index: Optional[str] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
batch_size: Optional[int] = None,
|
||||
@ -318,6 +327,10 @@ class BM25Retriever(BaseRetriever):
|
||||
}
|
||||
```
|
||||
:param top_k: How many documents to return per query.
|
||||
:param all_terms_must_match: Whether all terms of the query must match the document.
|
||||
When set to `True`, the Retriever returns only documents that contain all query terms (that means the AND operator is being used implicitly between query terms. For example, the query "cozy fish restaurant" is read as "cozy AND fish AND restaurant").
|
||||
When set to `False`, the Retriever returns documents containing at least one query term (this means the OR operator is being used implicitly between query terms. For example, the query "cozy fish restaurant" is read as "cozy OR fish OR restaurant").
|
||||
Defaults to `None`. If you set a value for this parameter, it overrides `self.all_terms_must_match` for this call; the value set on the Retriever itself is left unchanged.
|
||||
:param index: The name of the index in the DocumentStore from which to retrieve documents
|
||||
:param headers: Custom HTTP headers to pass to elasticsearch client (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='})
|
||||
Check out https://www.elastic.co/guide/en/elasticsearch/reference/current/http-clients.html for more information.
|
||||
@ -342,12 +355,14 @@ class BM25Retriever(BaseRetriever):
|
||||
index = document_store.index
|
||||
if scale_score is None:
|
||||
scale_score = self.scale_score
|
||||
if all_terms_must_match is None:
|
||||
all_terms_must_match = self.all_terms_must_match
|
||||
|
||||
documents = document_store.query_batch(
|
||||
queries=queries,
|
||||
filters=filters,
|
||||
top_k=top_k,
|
||||
all_terms_must_match=self.all_terms_must_match,
|
||||
all_terms_must_match=all_terms_must_match,
|
||||
custom_query=self.custom_query,
|
||||
index=index,
|
||||
headers=headers,
|
||||
|
||||
@ -655,6 +655,47 @@ def test_elasticsearch_all_terms_must_match():
|
||||
doc_store.delete_index(index)
|
||||
|
||||
|
||||
@pytest.mark.elasticsearch
def test_bm25retriever_all_terms_must_match():
    """
    Check that BM25Retriever applies `all_terms_must_match` both when it is set
    on the Retriever and when it is passed to `retrieve()` at query time.

    All four documents mention "green" and "tea", but only document "1" also
    contains "drink". The query "drink green tea" is therefore expected to hit
    four documents when terms are OR-ed and one document when they are AND-ed
    (assuming the index analyzer matches case-insensitively, as the original
    assertions already rely on).
    """
    index = "all_terms_must_match"
    query = "drink green tea"

    # Start from a clean index; a 404 only means there was nothing to delete.
    client = Elasticsearch()
    client.indices.delete(index=index, ignore=[404])

    documents = [
        {
            "content": "The green tea plant contains a range of healthy compounds that make it into the final drink",
            "meta": {"content_type": "text"},
            "id": "1",
        },
        {
            "content": "Green tea contains a catechin called epigallocatechin-3-gallate (EGCG).",
            "meta": {"content_type": "text"},
            "id": "2",
        },
        {
            "content": "Green tea also has small amounts of minerals that can benefit your health.",
            "meta": {"content_type": "text"},
            "id": "3",
        },
        {
            "content": "Green tea does more than just keep you alert, it may also help boost brain function.",
            "meta": {"content_type": "text"},
            "id": "4",
        },
    ]

    doc_store = ElasticsearchDocumentStore(index=index)
    # try/finally so the index is removed even when an assertion fails.
    try:
        doc_store.write_documents(documents)

        # Retriever default: any query term may match.
        retriever = BM25Retriever(document_store=doc_store)
        results_wo_all_terms_must_match = retriever.retrieve(query=query)
        assert len(results_wo_all_terms_must_match) == 4

        # Set on the Retriever: every query term must match.
        retriever = BM25Retriever(document_store=doc_store, all_terms_must_match=True)
        results_w_all_terms_must_match = retriever.retrieve(query=query)
        assert len(results_w_all_terms_must_match) == 1

        # An explicit runtime `False` must win over the Retriever's `True`;
        # only `None` falls back to the value set on the Retriever.
        results_runtime_false = retriever.retrieve(query=query, all_terms_must_match=False)
        assert len(results_runtime_false) == 4

        # Set at query time only: overrides the Retriever default.
        retriever = BM25Retriever(document_store=doc_store)
        results_runtime_true = retriever.retrieve(query=query, all_terms_must_match=True)
        assert len(results_runtime_true) == 1
    finally:
        doc_store.delete_index(index)
|
||||
|
||||
|
||||
def test_embeddings_encoder_of_embedding_retriever_should_warn_about_model_format(caplog):
|
||||
document_store = InMemoryDocumentStore()
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user