feat: Adds all_terms_must_match parameter to BM25Retriever at runtime (#3627)

* Adds all_terms_must_match implementation and tests

* Adds all_terms_must_match as Optional

Signed-off-by: Unai Garay <unaigaraymaestre@gmail.com>

* Avoid mypy error and follow pattern checking var is None

* Mypy works ok on this file now

* added mypy ignores to BaseRetriever

* ignoring all overrides for this file

* Updates sparse retriever `all_terms_must_match` docstring

Co-authored-by: Agnieszka Marzec <97166305+agnieszka-m@users.noreply.github.com>

* Updates sparse retriever `all_terms_must_match` docstring

Co-authored-by: Agnieszka Marzec <97166305+agnieszka-m@users.noreply.github.com>

* Updates sparse retriever `all_terms_must_match` docstring

Co-authored-by: Agnieszka Marzec <97166305+agnieszka-m@users.noreply.github.com>

* Updates sparse retrieve_batch `all_terms_must_match` docstring

Co-authored-by: Agnieszka Marzec <97166305+agnieszka-m@users.noreply.github.com>

* Updates sparse retrieve_batch `all_terms_must_match` docstring

Co-authored-by: Agnieszka Marzec <97166305+agnieszka-m@users.noreply.github.com>

* Updates sparse retrieve_batch `all_terms_must_match` docstring

Co-authored-by: Agnieszka Marzec <97166305+agnieszka-m@users.noreply.github.com>

* marked elasticsearch

Signed-off-by: Unai Garay <unaigaraymaestre@gmail.com>
Co-authored-by: Mayank Jobanputra <mayankjobanputra@gmail.com>
Co-authored-by: Agnieszka Marzec <97166305+agnieszka-m@users.noreply.github.com>
This commit is contained in:
Unai Garay Maestre 2022-12-08 12:48:43 +01:00 committed by GitHub
parent c1c1c97bb2
commit 77cea8b140
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 58 additions and 2 deletions

View File

@ -1,3 +1,4 @@
# mypy: disable-error-code=override
from typing import Dict, List, Optional, Union
import logging
@ -118,6 +119,7 @@ class BM25Retriever(BaseRetriever):
query: str,
filters: Optional[Dict[str, Union[Dict, List, str, int, float, bool]]] = None,
top_k: Optional[int] = None,
all_terms_must_match: Optional[bool] = None,
index: Optional[str] = None,
headers: Optional[Dict[str, str]] = None,
scale_score: Optional[bool] = None,
@ -194,6 +196,10 @@ class BM25Retriever(BaseRetriever):
}
```
:param top_k: How many documents to return per query.
:param all_terms_must_match: Whether all terms of the query must match the document.
When set to `True`, the Retriever returns only documents that contain all query terms (that means the AND operator is being used implicitly between query terms. For example, the query "cozy fish restaurant" is read as "cozy AND fish AND restaurant").
When set to `False`, the Retriever returns documents containing at least one query term (this means the OR operator is being used implicitly between query terms. For example, the query "cozy fish restaurant" is read as "cozy OR fish OR restaurant").
Defaults to `None`. If you set a value for this parameter, it overwrites self.all_terms_must_match at runtime.
:param index: The name of the index in the DocumentStore from which to retrieve documents
:param headers: Custom HTTP headers to pass to elasticsearch client (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='})
Check out https://www.elastic.co/guide/en/elasticsearch/reference/current/http-clients.html for more information.
@ -216,12 +222,14 @@ class BM25Retriever(BaseRetriever):
index = document_store.index
if scale_score is None:
scale_score = self.scale_score
if all_terms_must_match is None:
all_terms_must_match = self.all_terms_must_match
documents = document_store.query(
query=query,
filters=filters,
top_k=top_k,
all_terms_must_match=self.all_terms_must_match,
all_terms_must_match=all_terms_must_match,
custom_query=self.custom_query,
index=index,
headers=headers,
@ -239,6 +247,7 @@ class BM25Retriever(BaseRetriever):
]
] = None,
top_k: Optional[int] = None,
all_terms_must_match: Optional[bool] = None,
index: Optional[str] = None,
headers: Optional[Dict[str, str]] = None,
batch_size: Optional[int] = None,
@ -318,6 +327,10 @@ class BM25Retriever(BaseRetriever):
}
```
:param top_k: How many documents to return per query.
:param all_terms_must_match: Whether all terms of the query must match the document.
When set to `True`, the Retriever returns only documents that contain all query terms (that means the AND operator is being used implicitly between query terms. For example, the query "cozy fish restaurant" is read as "cozy AND fish AND restaurant").
When set to `False`, the Retriever returns documents containing at least one query term (this means the OR operator is being used implicitly between query terms. For example, the query "cozy fish restaurant" is read as "cozy OR fish OR restaurant").
Defaults to `None`. If you set a value for this parameter, it overwrites self.all_terms_must_match at runtime.
:param index: The name of the index in the DocumentStore from which to retrieve documents
:param headers: Custom HTTP headers to pass to elasticsearch client (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='})
Check out https://www.elastic.co/guide/en/elasticsearch/reference/current/http-clients.html for more information.
@ -342,12 +355,14 @@ class BM25Retriever(BaseRetriever):
index = document_store.index
if scale_score is None:
scale_score = self.scale_score
if all_terms_must_match is None:
all_terms_must_match = self.all_terms_must_match
documents = document_store.query_batch(
queries=queries,
filters=filters,
top_k=top_k,
all_terms_must_match=self.all_terms_must_match,
all_terms_must_match=all_terms_must_match,
custom_query=self.custom_query,
index=index,
headers=headers,

View File

@ -655,6 +655,47 @@ def test_elasticsearch_all_terms_must_match():
doc_store.delete_index(index)
@pytest.mark.elasticsearch
def test_bm25retriever_all_terms_must_match():
    """
    Check that `all_terms_must_match` can be set on BM25Retriever both at
    construction time and per call to `retrieve()`.

    Four documents are indexed and queried with "drink green tea". The test
    expects four hits when the flag is left unset, and a single hit when the
    flag is `True`, whether it is passed to the constructor or to `retrieve()`.
    """
    index_name = "all_terms_must_match"

    # Drop any index left over from an earlier run; a missing index (404) is fine.
    es_client = Elasticsearch()
    es_client.indices.delete(index=index_name, ignore=[404])

    contents = [
        "The green tea plant contains a range of healthy compounds that make it into the final drink",
        "Green tea contains a catechin called epigallocatechin-3-gallate (EGCG).",
        "Green tea also has small amounts of minerals that can benefit your health.",
        "Green tea does more than just keep you alert, it may also help boost brain function.",
    ]
    # Ids are the 1-based position of each text, as strings: "1" .. "4".
    documents = [
        {"content": text, "meta": {"content_type": "text"}, "id": str(position)}
        for position, text in enumerate(contents, start=1)
    ]

    doc_store = ElasticsearchDocumentStore(index=index_name)
    doc_store.write_documents(documents)

    query = "drink green tea"

    # Flag left unset everywhere: every document is expected back.
    default_retriever = BM25Retriever(document_store=doc_store)
    hits_without_flag = default_retriever.retrieve(query=query)
    assert len(hits_without_flag) == 4

    # Flag set when the retriever is constructed.
    strict_retriever = BM25Retriever(document_store=doc_store, all_terms_must_match=True)
    hits_with_init_flag = strict_retriever.retrieve(query=query)
    assert len(hits_with_init_flag) == 1

    # Flag supplied only at call time on a retriever built without it.
    runtime_retriever = BM25Retriever(document_store=doc_store)
    hits_with_runtime_flag = runtime_retriever.retrieve(query=query, all_terms_must_match=True)
    assert len(hits_with_runtime_flag) == 1

    doc_store.delete_index(index_name)
def test_embeddings_encoder_of_embedding_retriever_should_warn_about_model_format(caplog):
document_store = InMemoryDocumentStore()