mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-01 10:19:23 +00:00
Fix ElasticsearchDocumentStore.query_by_embedding() (#823)
This commit is contained in:
parent
8adf5b4737
commit
4059805d89
@ -497,15 +497,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
|
||||
"""
|
||||
Return all documents in a specific index in the document store
|
||||
"""
|
||||
body = {
|
||||
"query": {
|
||||
"bool": {
|
||||
"must": {
|
||||
"match_all": {}
|
||||
}
|
||||
}
|
||||
}
|
||||
} # type: Dict[str, Any]
|
||||
body: dict = {"query": {"bool": {}}}
|
||||
|
||||
if filters:
|
||||
filter_clause = []
|
||||
@ -640,13 +632,17 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
|
||||
"query": self._get_vector_similarity_query(query_emb, top_k)
|
||||
}
|
||||
if filters:
|
||||
filter_clause = []
|
||||
for key, values in filters.items():
|
||||
if type(values) != list:
|
||||
raise ValueError(f'Wrong filter format for key "{key}": Please provide a list of allowed values for each key. '
|
||||
'Example: {"name": ["some", "more"], "category": ["only_one"]} ')
|
||||
body["query"]["script_score"]["query"] = {
|
||||
"terms": filters
|
||||
}
|
||||
filter_clause.append(
|
||||
{
|
||||
"terms": {key: values}
|
||||
}
|
||||
)
|
||||
body["query"]["script_score"]["query"] = {"bool": {"filter": filter_clause}}
|
||||
|
||||
excluded_meta_data: Optional[list] = None
|
||||
|
||||
|
||||
@ -244,7 +244,8 @@ def get_retriever(retriever_type, document_store):
|
||||
passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
|
||||
use_gpu=False, embed_title=True)
|
||||
elif retriever_type == "tfidf":
|
||||
return TfidfRetriever(document_store=document_store)
|
||||
retriever = TfidfRetriever(document_store=document_store)
|
||||
retriever.fit()
|
||||
elif retriever_type == "embedding":
|
||||
retriever = EmbeddingRetriever(
|
||||
document_store=document_store,
|
||||
|
||||
@ -212,10 +212,13 @@ def test_update_embeddings(document_store, retriever):
|
||||
|
||||
documents = document_store.get_all_documents(
|
||||
index="haystack_test_1",
|
||||
filters={"meta_field": ["value_0", "value_6"]},
|
||||
filters={"meta_field": ["value_0"]},
|
||||
return_embedding=True,
|
||||
)
|
||||
np.testing.assert_array_equal(documents[0].embedding, documents[1].embedding)
|
||||
assert len(documents) == 2
|
||||
for doc in documents:
|
||||
assert doc.meta["meta_field"] == "value_0"
|
||||
np.testing.assert_array_almost_equal(documents[0].embedding, documents[1].embedding)
|
||||
|
||||
documents = document_store.get_all_documents(
|
||||
index="haystack_test_1",
|
||||
|
||||
@ -1,24 +0,0 @@
|
||||
from haystack import Document
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.elasticsearch
|
||||
@pytest.mark.parametrize("document_store_with_docs", [("elasticsearch")], indirect=True)
|
||||
@pytest.mark.parametrize("retriever_with_docs", ["es_filter_only"], indirect=True)
|
||||
def test_dummy_retriever(retriever_with_docs, document_store_with_docs):
|
||||
|
||||
result = retriever_with_docs.retrieve(query="godzilla", filters={"name": ["filename1"]}, top_k=1)
|
||||
assert type(result[0]) == Document
|
||||
assert result[0].text == "My name is Carla and I live in Berlin"
|
||||
assert result[0].meta["name"] == "filename1"
|
||||
|
||||
result = retriever_with_docs.retrieve(query="godzilla", filters={"name": ["filename1"]}, top_k=5)
|
||||
assert type(result[0]) == Document
|
||||
assert result[0].text == "My name is Carla and I live in Berlin"
|
||||
assert result[0].meta["name"] == "filename1"
|
||||
|
||||
result = retriever_with_docs.retrieve(query="godzilla", filters={"name": ["filename3"]}, top_k=5)
|
||||
assert type(result[0]) == Document
|
||||
assert result[0].text == "My name is Christelle and I live in Paris"
|
||||
assert result[0].meta["name"] == "filename3"
|
||||
|
||||
@ -1,85 +0,0 @@
|
||||
import pytest
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
|
||||
from haystack.retriever.sparse import ElasticsearchRetriever
|
||||
|
||||
|
||||
@pytest.mark.elasticsearch
|
||||
@pytest.mark.parametrize("document_store_with_docs", [("elasticsearch")], indirect=True)
|
||||
@pytest.mark.parametrize("retriever_with_docs", ["elasticsearch"], indirect=True)
|
||||
def test_elasticsearch_retrieval(retriever_with_docs, document_store_with_docs):
|
||||
res = retriever_with_docs.retrieve(query="Who lives in Berlin?")
|
||||
assert res[0].text == "My name is Carla and I live in Berlin"
|
||||
assert len(res) == 3
|
||||
assert res[0].meta["name"] == "filename1"
|
||||
|
||||
|
||||
@pytest.mark.elasticsearch
|
||||
@pytest.mark.parametrize("document_store_with_docs", [("elasticsearch")], indirect=True)
|
||||
@pytest.mark.parametrize("retriever_with_docs", ["elasticsearch"], indirect=True)
|
||||
def test_elasticsearch_retrieval_filters(retriever_with_docs, document_store_with_docs):
|
||||
res = retriever_with_docs.retrieve(query="Who lives in Berlin?", filters={"name": ["filename1"]})
|
||||
assert res[0].text == "My name is Carla and I live in Berlin"
|
||||
assert len(res) == 1
|
||||
assert res[0].meta["name"] == "filename1"
|
||||
|
||||
res = retriever_with_docs.retrieve(query="Who lives in Berlin?", filters={"name":["filename1"], "meta_field": ["not_existing_value"]})
|
||||
assert len(res) == 0
|
||||
|
||||
res = retriever_with_docs.retrieve(query="Who lives in Berlin?", filters={"name":["filename1"], "not_existing_field": ["not_existing_value"]})
|
||||
assert len(res) == 0
|
||||
|
||||
res = retriever_with_docs.retrieve(query="Who lives in Berlin?", filters={"name":["filename1"], "meta_field": ["test1","test2"]})
|
||||
assert res[0].text == "My name is Carla and I live in Berlin"
|
||||
assert len(res) == 1
|
||||
assert res[0].meta["name"] == "filename1"
|
||||
|
||||
res = retriever_with_docs.retrieve(query="Who lives in Berlin?", filters={"name":["filename1"], "meta_field":["test2"]})
|
||||
assert len(res) == 0
|
||||
|
||||
|
||||
@pytest.mark.elasticsearch
|
||||
def test_elasticsearch_custom_query(elasticsearch_fixture):
|
||||
client = Elasticsearch()
|
||||
client.indices.delete(index='haystack_test_custom', ignore=[404])
|
||||
document_store = ElasticsearchDocumentStore(index="haystack_test_custom", text_field="custom_text_field",
|
||||
embedding_field="custom_embedding_field")
|
||||
documents = [
|
||||
{"text": "test_1", "meta": {"year": "2019"}},
|
||||
{"text": "test_2", "meta": {"year": "2020"}},
|
||||
{"text": "test_3", "meta": {"year": "2021"}},
|
||||
{"text": "test_4", "meta": {"year": "2021"}},
|
||||
{"text": "test_5", "meta": {"year": "2021"}},
|
||||
]
|
||||
document_store.write_documents(documents)
|
||||
|
||||
# test custom "terms" query
|
||||
retriever = ElasticsearchRetriever(
|
||||
document_store=document_store,
|
||||
custom_query="""
|
||||
{
|
||||
"size": 10,
|
||||
"query": {
|
||||
"bool": {
|
||||
"should": [{
|
||||
"multi_match": {"query": ${query}, "type": "most_fields", "fields": ["text"]}}],
|
||||
"filter": [{"terms": {"year": ${years}}}]}}}"""
|
||||
)
|
||||
results = retriever.run(query="test", filters={"years": ["2020", "2021"]})[0]["documents"]
|
||||
assert len(results) == 4
|
||||
|
||||
# test custom "term" query
|
||||
retriever = ElasticsearchRetriever(
|
||||
document_store=document_store,
|
||||
custom_query="""
|
||||
{
|
||||
"size": 10,
|
||||
"query": {
|
||||
"bool": {
|
||||
"should": [{
|
||||
"multi_match": {"query": ${query}, "type": "most_fields", "fields": ["text"]}}],
|
||||
"filter": [{"term": {"year": ${years}}}]}}}"""
|
||||
)
|
||||
results = retriever.run(query="test", filters={"years": "2021"})[0]["documents"]
|
||||
assert len(results) == 3
|
||||
@ -1,35 +0,0 @@
|
||||
import pytest
|
||||
from haystack import Finder
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.elasticsearch
|
||||
@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory", "milvus"], indirect=True)
|
||||
@pytest.mark.parametrize("retriever", ["embedding"], indirect=True)
|
||||
def test_embedding_retriever(retriever, document_store):
|
||||
|
||||
documents = [
|
||||
{'text': 'By running tox in the command line!', 'meta': {'name': 'How to test this library?', 'question': 'How to test this library?'}},
|
||||
{'text': 'By running tox in the command line!', 'meta': {'name': 'blah blah blah', 'question': 'blah blah blah'}},
|
||||
{'text': 'By running tox in the command line!', 'meta': {'name': 'blah blah blah', 'question': 'blah blah blah'}},
|
||||
{'text': 'By running tox in the command line!', 'meta': {'name': 'blah blah blah', 'question': 'blah blah blah'}},
|
||||
{'text': 'By running tox in the command line!', 'meta': {'name': 'blah blah blah', 'question': 'blah blah blah'}},
|
||||
{'text': 'By running tox in the command line!', 'meta': {'name': 'blah blah blah', 'question': 'blah blah blah'}},
|
||||
{'text': 'By running tox in the command line!', 'meta': {'name': 'blah blah blah', 'question': 'blah blah blah'}},
|
||||
{'text': 'By running tox in the command line!', 'meta': {'name': 'blah blah blah', 'question': 'blah blah blah'}},
|
||||
{'text': 'By running tox in the command line!', 'meta': {'name': 'blah blah blah', 'question': 'blah blah blah'}},
|
||||
{'text': 'By running tox in the command line!', 'meta': {'name': 'blah blah blah', 'question': 'blah blah blah'}},
|
||||
{'text': 'By running tox in the command line!', 'meta': {'name': 'blah blah blah', 'question': 'blah blah blah'}},
|
||||
]
|
||||
|
||||
embedded = []
|
||||
for doc in documents:
|
||||
doc['embedding'] = retriever.embed([doc['meta']['question']])[0]
|
||||
embedded.append(doc)
|
||||
|
||||
document_store.write_documents(embedded)
|
||||
|
||||
finder = Finder(reader=None, retriever=retriever)
|
||||
prediction = finder.get_answers_via_similar_questions(question="How to test this?", top_k_retriever=1)
|
||||
|
||||
assert len(prediction.get('answers', [])) == 1
|
||||
@ -1,21 +1,122 @@
|
||||
import pytest
|
||||
import time
|
||||
import numpy as np
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from elasticsearch import Elasticsearch
|
||||
from haystack import Document
|
||||
from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
|
||||
from haystack.document_store.faiss import FAISSDocumentStore
|
||||
from haystack.document_store.milvus import MilvusDocumentStore
|
||||
from haystack.retriever.dense import DensePassageRetriever
|
||||
|
||||
from haystack.retriever.sparse import ElasticsearchRetriever, ElasticsearchFilterOnlyRetriever, TfidfRetriever
|
||||
from transformers import DPRContextEncoderTokenizerFast, DPRQuestionEncoderTokenizerFast
|
||||
|
||||
|
||||
@pytest.mark.elasticsearch
|
||||
@pytest.mark.parametrize(
|
||||
"retriever_with_docs,document_store_with_docs",
|
||||
[
|
||||
("dpr", "elasticsearch"),
|
||||
("dpr", "faiss"),
|
||||
("dpr", "memory"),
|
||||
("dpr", "milvus"),
|
||||
("embedding", "elasticsearch"),
|
||||
("embedding", "faiss"),
|
||||
("embedding", "memory"),
|
||||
("embedding", "milvus"),
|
||||
("elasticsearch", "elasticsearch"),
|
||||
("es_filter_only", "elasticsearch"),
|
||||
("tfidf", "memory"),
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
def test_retrieval(retriever_with_docs, document_store_with_docs):
|
||||
if not isinstance(retriever_with_docs, (ElasticsearchRetriever, ElasticsearchFilterOnlyRetriever, TfidfRetriever)):
|
||||
document_store_with_docs.update_embeddings(retriever_with_docs)
|
||||
|
||||
# test without filters
|
||||
res = retriever_with_docs.retrieve(query="Who lives in Berlin?")
|
||||
assert res[0].text == "My name is Carla and I live in Berlin"
|
||||
assert len(res) == 3
|
||||
assert res[0].meta["name"] == "filename1"
|
||||
|
||||
# test with filters
|
||||
if not isinstance(document_store_with_docs, (FAISSDocumentStore, MilvusDocumentStore)) and not isinstance(
|
||||
retriever_with_docs, TfidfRetriever
|
||||
):
|
||||
# single filter
|
||||
result = retriever_with_docs.retrieve(query="godzilla", filters={"name": ["filename3"]}, top_k=5)
|
||||
assert len(result) == 1
|
||||
assert type(result[0]) == Document
|
||||
assert result[0].text == "My name is Christelle and I live in Paris"
|
||||
assert result[0].meta["name"] == "filename3"
|
||||
|
||||
# multiple filters
|
||||
result = retriever_with_docs.retrieve(
|
||||
query="godzilla", filters={"name": ["filename2"], "meta_field": ["test2", "test3"]}, top_k=5
|
||||
)
|
||||
assert len(result) == 1
|
||||
assert type(result[0]) == Document
|
||||
assert result[0].meta["name"] == "filename2"
|
||||
|
||||
result = retriever_with_docs.retrieve(
|
||||
query="godzilla", filters={"name": ["filename1"], "meta_field": ["test2", "test3"]}, top_k=5
|
||||
)
|
||||
assert len(result) == 0
|
||||
|
||||
|
||||
@pytest.mark.elasticsearch
|
||||
def test_elasticsearch_custom_query(elasticsearch_fixture):
|
||||
client = Elasticsearch()
|
||||
client.indices.delete(index="haystack_test_custom", ignore=[404])
|
||||
document_store = ElasticsearchDocumentStore(
|
||||
index="haystack_test_custom", text_field="custom_text_field", embedding_field="custom_embedding_field"
|
||||
)
|
||||
documents = [
|
||||
{"text": "test_1", "meta": {"year": "2019"}},
|
||||
{"text": "test_2", "meta": {"year": "2020"}},
|
||||
{"text": "test_3", "meta": {"year": "2021"}},
|
||||
{"text": "test_4", "meta": {"year": "2021"}},
|
||||
{"text": "test_5", "meta": {"year": "2021"}},
|
||||
]
|
||||
document_store.write_documents(documents)
|
||||
|
||||
# test custom "terms" query
|
||||
retriever = ElasticsearchRetriever(
|
||||
document_store=document_store,
|
||||
custom_query="""
|
||||
{
|
||||
"size": 10,
|
||||
"query": {
|
||||
"bool": {
|
||||
"should": [{
|
||||
"multi_match": {"query": ${query}, "type": "most_fields", "fields": ["text"]}}],
|
||||
"filter": [{"terms": {"year": ${years}}}]}}}""",
|
||||
)
|
||||
results = retriever.run(query="test", filters={"years": ["2020", "2021"]})[0]["documents"]
|
||||
assert len(results) == 4
|
||||
|
||||
# test custom "term" query
|
||||
retriever = ElasticsearchRetriever(
|
||||
document_store=document_store,
|
||||
custom_query="""
|
||||
{
|
||||
"size": 10,
|
||||
"query": {
|
||||
"bool": {
|
||||
"should": [{
|
||||
"multi_match": {"query": ${query}, "type": "most_fields", "fields": ["text"]}}],
|
||||
"filter": [{"term": {"year": ${years}}}]}}}""",
|
||||
)
|
||||
results = retriever.run(query="test", filters={"years": "2021"})[0]["documents"]
|
||||
assert len(results) == 3
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.elasticsearch
|
||||
@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory", "milvus"], indirect=True)
|
||||
@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
|
||||
@pytest.mark.parametrize("return_embedding", [True, False])
|
||||
def test_dpr_retrieval(document_store, retriever, return_embedding):
|
||||
|
||||
def test_dpr_embedding(document_store, retriever):
|
||||
documents = [
|
||||
Document(
|
||||
text="""Aaron Aaron ( or ; ""Ahärôn"") is a prophet, high priest, and the brother of Moses in the Abrahamic religions. Knowledge of Aaron, along with his brother Moses, comes exclusively from religious texts, such as the Bible and Quran. The Hebrew Bible relates that, unlike Moses, who grew up in the Egyptian royal court, Aaron and his elder sister Miriam remained with their kinsmen in the eastern border-land of Egypt (Goshen). When Moses first confronted the Egyptian king about the Israelites, Aaron served as his brother's spokesman (""prophet"") to the Pharaoh. Part of the Law (Torah) that Moses received from""",
|
||||
@ -40,55 +141,39 @@ def test_dpr_retrieval(document_store, retriever, return_embedding):
|
||||
text="""The title of the episode refers to the Great Sept of Baelor, the main religious building in King's Landing, where the episode's pivotal scene takes place. In the world created by George R. R. Martin""",
|
||||
meta={},
|
||||
id="5",
|
||||
)
|
||||
),
|
||||
]
|
||||
|
||||
document_store.return_embedding = return_embedding
|
||||
document_store.return_embedding = True
|
||||
document_store.write_documents(documents)
|
||||
document_store.update_embeddings(retriever=retriever)
|
||||
time.sleep(1)
|
||||
|
||||
if return_embedding is True:
|
||||
doc_1 = document_store.get_document_by_id("1")
|
||||
assert (len(doc_1.embedding) == 768)
|
||||
assert (abs(doc_1.embedding[0] - (-0.3063)) < 0.001)
|
||||
doc_2 = document_store.get_document_by_id("2")
|
||||
assert (abs(doc_2.embedding[0] - (-0.3914)) < 0.001)
|
||||
doc_3 = document_store.get_document_by_id("3")
|
||||
assert (abs(doc_3.embedding[0] - (-0.2470)) < 0.001)
|
||||
doc_4 = document_store.get_document_by_id("4")
|
||||
assert (abs(doc_4.embedding[0] - (-0.0802)) < 0.001)
|
||||
doc_5 = document_store.get_document_by_id("5")
|
||||
assert (abs(doc_5.embedding[0] - (-0.0551)) < 0.001)
|
||||
|
||||
res = retriever.retrieve(query="Which philosopher attacked Schopenhauer?")
|
||||
|
||||
assert res[0].meta["name"] == "1"
|
||||
|
||||
# test embedding
|
||||
if return_embedding is True:
|
||||
assert res[0].embedding is not None
|
||||
else:
|
||||
assert res[0].embedding is None
|
||||
|
||||
# test filtering
|
||||
if not isinstance(document_store, FAISSDocumentStore) and not isinstance(document_store, MilvusDocumentStore):
|
||||
res = retriever.retrieve(query="Which philosopher attacked Schopenhauer?", filters={"name": ["0", "2"]})
|
||||
assert len(res) == 2
|
||||
for r in res:
|
||||
assert r.meta["name"] in ["0", "2"]
|
||||
doc_1 = document_store.get_document_by_id("1")
|
||||
assert len(doc_1.embedding) == 768
|
||||
assert abs(doc_1.embedding[0] - (-0.3063)) < 0.001
|
||||
doc_2 = document_store.get_document_by_id("2")
|
||||
assert abs(doc_2.embedding[0] - (-0.3914)) < 0.001
|
||||
doc_3 = document_store.get_document_by_id("3")
|
||||
assert abs(doc_3.embedding[0] - (-0.2470)) < 0.001
|
||||
doc_4 = document_store.get_document_by_id("4")
|
||||
assert abs(doc_4.embedding[0] - (-0.0802)) < 0.001
|
||||
doc_5 = document_store.get_document_by_id("5")
|
||||
assert abs(doc_5.embedding[0] - (-0.0551)) < 0.001
|
||||
|
||||
|
||||
@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
|
||||
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
|
||||
def test_dpr_saving_and_loading(retriever, document_store):
|
||||
retriever.save("test_dpr_save")
|
||||
|
||||
def sum_params(model):
|
||||
s = []
|
||||
for p in model.parameters():
|
||||
n = p.cpu().data.numpy()
|
||||
s.append(np.sum(n))
|
||||
return sum(s)
|
||||
|
||||
original_sum_query = sum_params(retriever.query_encoder)
|
||||
original_sum_passage = sum_params(retriever.passage_encoder)
|
||||
del retriever
|
||||
@ -123,4 +208,3 @@ def test_dpr_saving_and_loading(retriever, document_store):
|
||||
assert loaded_retriever.query_tokenizer.vocab_size == 30522
|
||||
assert loaded_retriever.passage_tokenizer.model_max_length == 512
|
||||
assert loaded_retriever.query_tokenizer.model_max_length == 512
|
||||
|
||||
@ -1,20 +0,0 @@
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.parametrize("retriever", ["tfidf"], indirect=True)
|
||||
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
|
||||
def test_tfidf_retriever(document_store, retriever):
|
||||
|
||||
test_docs = [
|
||||
{"id": "26f84672c6d7aaeb8e2cd53e9c62d62d", "name": "testing the finder 1", "text": "godzilla says hello"},
|
||||
{"name": "testing the finder 2", "text": "optimus prime says bye"},
|
||||
{"name": "testing the finder 3", "text": "alien says arghh"}
|
||||
]
|
||||
|
||||
document_store.write_documents(test_docs)
|
||||
|
||||
retriever.fit()
|
||||
doc = retriever.retrieve("godzilla", top_k=1)[0]
|
||||
assert doc.id == "26f84672c6d7aaeb8e2cd53e9c62d62d"
|
||||
assert doc.text == 'godzilla says hello'
|
||||
assert doc.meta == {"name": "testing the finder 1"}
|
||||
Loading…
x
Reference in New Issue
Block a user