2020-08-10 05:34:39 -04:00
|
|
|
import numpy as np
|
2020-07-16 15:34:55 +02:00
|
|
|
import pytest
|
2020-08-10 05:34:39 -04:00
|
|
|
from elasticsearch import Elasticsearch
|
2020-07-16 15:34:55 +02:00
|
|
|
|
2020-09-16 18:33:23 +02:00
|
|
|
from haystack import Document, Label
|
|
|
|
from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
|
|
|
|
from haystack.document_store.faiss import FAISSDocumentStore
|
2020-07-14 09:53:31 +02:00
|
|
|
|
|
|
|
|
2020-10-30 18:06:02 +01:00
|
|
|
@pytest.mark.elasticsearch
def test_get_all_documents_without_filters(document_store_with_docs):
    """An unfiltered fetch yields ``Document`` objects for all three fixture docs."""
    all_docs = document_store_with_docs.get_all_documents()

    assert all(isinstance(doc, Document) for doc in all_docs)
    assert len(all_docs) == 3

    # Compare as sets: the test does not depend on the order of the results.
    names = {doc.meta["name"] for doc in all_docs}
    assert names == {"filename1", "filename2", "filename3"}
    meta_fields = {doc.meta["meta_field"] for doc in all_docs}
    assert meta_fields == {"test1", "test2", "test3"}
|
2020-08-04 14:24:12 +02:00
|
|
|
|
|
|
|
|
2020-10-30 18:06:02 +01:00
|
|
|
@pytest.mark.elasticsearch
def test_get_all_document_filter_duplicate_value(document_store):
    """Filter on ``f1 == "1"`` while the value ``"0"`` is repeated under other keys.

    Three documents are written: the value ``"0"`` appears under ``f1``,
    ``vector_id`` and ``f3`` on different documents. Only the document whose
    ``f1`` is ``"1"`` is expected back from the filtered query.
    """
    documents = [
        Document(text="Doc1", meta={"f1": "0"}),
        Document(text="Doc1", meta={"f1": "1", "vector_id": "0"}),
        Document(text="Doc2", meta={"f3": "0"}),
    ]
    document_store.write_documents(documents)

    matches = document_store.get_all_documents(filters={"f1": ["1"]})

    # Assert the count before indexing so that an empty result is reported as
    # a readable assertion failure instead of an IndexError on ``matches[0]``.
    assert len(matches) == 1
    assert matches[0].text == "Doc1"
    assert {d.meta["vector_id"] for d in matches} == {"0"}
|
|
|
|
|
|
|
|
|
2020-10-30 18:06:02 +01:00
|
|
|
@pytest.mark.elasticsearch
def test_get_all_documents_with_correct_filters(document_store_with_docs):
    """Filtering on ``meta_field`` returns the documents matching the listed values."""
    # One allowed value -> one document.
    single_match = document_store_with_docs.get_all_documents(filters={"meta_field": ["test2"]})
    assert len(single_match) == 1
    assert single_match[0].meta["name"] == "filename2"

    # Two allowed values -> both matching documents.
    multi_match = document_store_with_docs.get_all_documents(filters={"meta_field": ["test1", "test3"]})
    assert len(multi_match) == 2
    assert {doc.meta["name"] for doc in multi_match} == {"filename1", "filename3"}
    assert {doc.meta["meta_field"] for doc in multi_match} == {"test1", "test3"}
|
|
|
|
|
|
|
|
|
2020-10-30 18:06:02 +01:00
|
|
|
@pytest.mark.elasticsearch
def test_get_all_documents_with_incorrect_filter_name(document_store_with_docs):
    """Filtering on the key ``incorrect_meta_field`` is expected to return nothing."""
    unknown_key_filter = {"incorrect_meta_field": ["test2"]}
    result = document_store_with_docs.get_all_documents(filters=unknown_key_filter)
    assert len(result) == 0
|
|
|
|
|
|
|
|
|
2020-10-30 18:06:02 +01:00
|
|
|
@pytest.mark.elasticsearch
def test_get_all_documents_with_incorrect_filter_value(document_store_with_docs):
    """Filtering ``meta_field`` on the value ``incorrect_value`` is expected to return nothing."""
    unknown_value_filter = {"meta_field": ["incorrect_value"]}
    result = document_store_with_docs.get_all_documents(filters=unknown_value_filter)
    assert len(result) == 0
|
|
|
|
|
|
|
|
|
2020-10-30 18:06:02 +01:00
|
|
|
@pytest.mark.elasticsearch
def test_get_documents_by_id(document_store_with_docs):
    """Fetching a document by its id returns a document with the same id and text."""
    # Use the first document from an unfiltered fetch as the reference.
    expected = document_store_with_docs.get_all_documents()[0]

    fetched = document_store_with_docs.get_document_by_id(expected.id)

    assert fetched.id == expected.id
    assert fetched.text == expected.text
|
2020-07-16 15:34:55 +02:00
|
|
|
|
|
|
|
|
2020-10-30 18:06:02 +01:00
|
|
|
@pytest.mark.elasticsearch
def test_get_document_count(document_store):
    """``get_document_count`` reports the total and the per-filter document counts."""
    rows = (
        ("text1", "1", "a"),
        ("text2", "2", "b"),
        ("text3", "3", "b"),
        ("text4", "4", "b"),
    )
    documents = [
        {"text": text, "id": doc_id, "meta_field_for_count": group}
        for text, doc_id, group in rows
    ]
    document_store.write_documents(documents)

    assert document_store.get_document_count() == 4
    # One document carries "a", three carry "b".
    for group, expected_count in (("a", 1), ("b", 3)):
        assert document_store.get_document_count(filters={"meta_field_for_count": [group]}) == expected_count
|
|
|
|
|
|
|
|
|
2020-10-30 18:06:02 +01:00
|
|
|
@pytest.mark.elasticsearch
def test_write_document_meta(document_store):
    """Meta data is kept for both dict and ``Document`` inputs, with and without meta."""

    def stored_meta(doc_id):
        # Meta of the stored document with the given id.
        return document_store.get_document_by_id(doc_id).meta

    plain_dict = {"text": "dict_without_meta", "id": "1"}
    dict_with_meta = {"text": "dict_with_meta", "meta_field": "test2", "name": "filename2", "id": "2"}
    plain_object = Document(text="document_object_without_meta", id="3")
    object_with_meta = Document(
        text="document_object_with_meta",
        meta={"meta_field": "test4", "name": "filename3"},
        id="4",
    )
    document_store.write_documents([plain_dict, dict_with_meta, plain_object, object_with_meta])

    documents_in_store = document_store.get_all_documents()
    assert len(documents_in_store) == 4

    assert not stored_meta("1")
    assert stored_meta("2")["meta_field"] == "test2"
    assert not stored_meta("3")
    assert stored_meta("4")["meta_field"] == "test4"
|
|
|
|
|
|
|
|
|
2020-10-30 18:06:02 +01:00
|
|
|
@pytest.mark.elasticsearch
def test_write_document_index(document_store):
    """Documents written to a named index are only visible through that index."""
    first_doc = {"text": "text1", "id": "1"}
    second_doc = {"text": "text2", "id": "2"}

    document_store.write_documents([first_doc], index="haystack_test_1")
    assert len(document_store.get_all_documents(index="haystack_test_1")) == 1

    # addition of more documents is not supported in FAISS
    is_faiss = isinstance(document_store, FAISSDocumentStore)
    if not is_faiss:
        document_store.write_documents([second_doc], index="haystack_test_2")
        assert len(document_store.get_all_documents(index="haystack_test_2")) == 1

    # The first index is unchanged, and a fetch without an index returns nothing.
    assert len(document_store.get_all_documents(index="haystack_test_1")) == 1
    assert len(document_store.get_all_documents()) == 0
|
|
|
|
|
2020-10-30 18:06:02 +01:00
|
|
|
|
|
|
|
@pytest.mark.elasticsearch
@pytest.mark.parametrize("document_store", ["elasticsearch"], indirect=True)
def test_write_document_with_embeddings(document_store):
    """Embeddings given as float32 array, float64 array, list or ``None`` can all be written."""
    # Generated in the same order as the documents they belong to.
    float32_vector = np.random.rand(768).astype(np.float32)
    float64_vector = np.random.rand(768).astype(np.float64)
    list_vector = np.random.rand(768).astype(np.float32).tolist()

    documents = [
        {"text": "text1", "id": "1", "embedding": float32_vector},
        {"text": "text2", "id": "2", "embedding": float64_vector},
        {"text": "text3", "id": "3", "embedding": list_vector},
        {"text": "text4", "id": "4", "embedding": None},
    ]
    document_store.write_documents(documents, index="haystack_test_1")

    stored = document_store.get_all_documents(index="haystack_test_1")
    assert len(stored) == 4
|
2020-08-13 12:25:32 +02:00
|
|
|
|
2020-10-30 18:06:02 +01:00
|
|
|
|
|
|
|
@pytest.mark.elasticsearch
def test_labels(document_store):
    """A label written to a named index is returned from it, and not when no index is given."""
    label_fields = {
        "question": "question",
        "answer": "answer",
        "is_correct_answer": True,
        "is_correct_document": True,
        "document_id": "123",
        "offset_start_in_doc": 12,
        "no_answer": False,
        "origin": "gold_label",
    }
    document_store.write_labels([Label(**label_fields)], index="haystack_test_label")

    labels_in_index = document_store.get_all_labels(index="haystack_test_label")
    assert len(labels_in_index) == 1

    labels_without_index = document_store.get_all_labels()
    assert len(labels_without_index) == 0
|
|
|
|
|
|
|
|
|
2020-10-30 18:06:02 +01:00
|
|
|
@pytest.mark.elasticsearch
def test_multilabel(document_store):
    """Labels for one question are aggregated into a single multi-label.

    Five labels sharing the same question are written. The aggregated result
    is expected to be one entry with three answers; the 'no answer' label and
    the label with ``is_correct_answer=False`` are expected to be left out.
    The raw label listing must still contain all five.
    """
    index = "haystack_test_multilabel"

    def make_label(answer, document_id, offset, is_correct_answer=True, no_answer=False):
        # All labels share the question, the document correctness and the origin.
        return Label(
            question="question",
            answer=answer,
            is_correct_answer=is_correct_answer,
            is_correct_document=True,
            document_id=document_id,
            offset_start_in_doc=offset,
            no_answer=no_answer,
            origin="gold_label",
        )

    labels = [
        make_label("answer1", "123", 12),
        # different answer in same doc
        make_label("answer2", "123", 42),
        # answer in different doc
        make_label("answer3", "321", 7),
        # 'no answer', should be excluded from MultiLabel
        make_label("", "777", 0, no_answer=True),
        # is_correct_answer=False, should be excluded from MultiLabel
        make_label("answer5", "123", 99, is_correct_answer=False, no_answer=True),
    ]
    document_store.write_labels(labels, index=index)

    try:
        multi_labels = document_store.get_all_labels_aggregated(index=index)
        labels = document_store.get_all_labels(index=index)

        assert len(multi_labels) == 1
        assert len(labels) == 5

        aggregated = multi_labels[0]
        assert len(aggregated.multiple_answers) == 3
        # The answer, document id and offset lists must stay aligned.
        assert (
            len(aggregated.multiple_answers)
            == len(aggregated.multiple_document_ids)
            == len(aggregated.multiple_offset_start_in_docs)
        )

        # Without an index argument no aggregated labels are expected.
        multi_labels = document_store.get_all_labels_aggregated()
        assert len(multi_labels) == 0
    finally:
        # clean up, even when an assertion above fails, so that stale labels
        # do not leak into later runs
        document_store.delete_all_documents(index=index)
|
|
|
|
|
|
|
|
|
2020-10-30 18:06:02 +01:00
|
|
|
@pytest.mark.elasticsearch
def test_multilabel_no_answer(document_store):
    """'No answer' labels for one question are aggregated per document.

    Four 'no answer' labels sharing the same question are written. The
    aggregated result is expected to be one entry covering two document ids;
    the repeated label for document "777" and the label with
    ``is_correct_answer=False`` are expected to be left out. The raw label
    listing must still contain all four.
    """
    index = "haystack_test_multilabel_no_answer"

    def make_no_answer_label(document_id, is_correct_answer=True):
        # Every label here is an empty 'no answer' at offset 0 for the same question.
        return Label(
            question="question",
            answer="",
            is_correct_answer=is_correct_answer,
            is_correct_document=True,
            document_id=document_id,
            offset_start_in_doc=0,
            no_answer=True,
            origin="gold_label",
        )

    labels = [
        make_no_answer_label("777"),
        # no answer in different doc
        make_no_answer_label("123"),
        # no answer in same doc, should be excluded
        make_no_answer_label("777"),
        # no answer with is_correct_answer=False, should be excluded
        make_no_answer_label("321", is_correct_answer=False),
    ]
    document_store.write_labels(labels, index=index)

    try:
        multi_labels = document_store.get_all_labels_aggregated(index=index)
        labels = document_store.get_all_labels(index=index)

        assert len(multi_labels) == 1
        assert len(labels) == 4

        aggregated = multi_labels[0]
        assert len(aggregated.multiple_document_ids) == 2
        # The answer, document id and offset lists must stay aligned.
        assert (
            len(aggregated.multiple_answers)
            == len(aggregated.multiple_document_ids)
            == len(aggregated.multiple_offset_start_in_docs)
        )
    finally:
        # clean up, even when an assertion above fails, so that stale labels
        # do not leak into later runs
        document_store.delete_all_documents(index=index)
|
|
|
|
|
|
|
|
|
2020-10-30 18:06:02 +01:00
|
|
|
@pytest.mark.elasticsearch
@pytest.mark.parametrize("document_store", ["elasticsearch", "sql"], indirect=True)
def test_elasticsearch_update_meta(document_store):
    """``update_document_meta`` changes the stored meta of the targeted document."""
    doc_1 = Document(text="Doc1", meta={"meta_key_1": "1", "meta_key_2": "1"})
    doc_2 = Document(text="Doc2", meta={"meta_key_1": "2", "meta_key_2": "2"})
    doc_3 = Document(text="Doc3", meta={"meta_key_1": "3", "meta_key_2": "3"})
    document_store.write_documents([doc_1, doc_2, doc_3])

    # Find the stored "Doc2" through its meta, then change only meta_key_1.
    target = document_store.get_all_documents(filters={"meta_key_2": ["2"]})[0]
    new_meta = {"meta_key_1": "99", "meta_key_2": "2"}
    document_store.update_document_meta(target.id, meta=new_meta)

    refreshed = document_store.get_document_by_id(target.id)
    assert len(refreshed.meta.keys()) == 2
    assert refreshed.meta["meta_key_1"] == "99"
    assert refreshed.meta["meta_key_2"] == "2"
|
2020-08-10 05:34:39 -04:00
|
|
|
|
|
|
|
|
2020-10-30 18:06:02 +01:00
|
|
|
@pytest.mark.elasticsearch
def test_elasticsearch_custom_fields(elasticsearch_fixture):
    """Custom text and embedding field names are mapped back to ``text`` and ``embedding``."""
    # Start from a clean index; a missing index (404) is not an error here.
    es_client = Elasticsearch()
    es_client.indices.delete(index='haystack_test_custom', ignore=[404])

    store = ElasticsearchDocumentStore(
        index="haystack_test_custom",
        text_field="custom_text_field",
        embedding_field="custom_embedding_field",
    )

    embedding = np.random.rand(768).astype(np.float32)
    doc_to_write = {"custom_text_field": "test", "custom_embedding_field": embedding}
    store.write_documents([doc_to_write])

    stored = store.get_all_documents()
    assert len(stored) == 1
    assert stored[0].text == "test"
    np.testing.assert_array_equal(doc_to_write["custom_embedding_field"], stored[0].embedding)
|