from unittest import mock
from unittest.mock import Mock

import numpy as np
import pandas as pd
import pytest
from elasticsearch import Elasticsearch
from elasticsearch.exceptions import RequestError

from conftest import get_document_store
from haystack.document_stores import WeaviateDocumentStore
from haystack.document_stores.base import BaseDocumentStore
from haystack.document_stores.elasticsearch import ElasticsearchDocumentStore
from haystack.document_stores.faiss import FAISSDocumentStore
from haystack.errors import DuplicateDocumentError
from haystack.schema import Document, Label, Answer, Span


@pytest.mark.elasticsearch
def test_init_elastic_client():
    # defaults
    _ = ElasticsearchDocumentStore()

    # list of hosts + single port
    _ = ElasticsearchDocumentStore(host=["localhost", "127.0.0.1"], port=9200)

    # list of hosts + mismatched list of ports (wrong)
    with pytest.raises(Exception):
        _ = ElasticsearchDocumentStore(host=["localhost", "127.0.0.1"], port=[9200])

    # list of hosts + one port per host
    _ = ElasticsearchDocumentStore(host=["localhost", "127.0.0.1"], port=[9200, 9200])

    # api_key without api_key_id (wrong)
    with pytest.raises(Exception):
        _ = ElasticsearchDocumentStore(host=["localhost"], port=[9200], api_key="test")

    # api_key + api_key_id
    _ = ElasticsearchDocumentStore(host=["localhost"], port=[9200], api_key="test", api_key_id="test")


def test_write_with_duplicate_doc_ids(document_store):
    duplicate_documents = [
        Document(
            content="Doc1",
            id_hash_keys=["content"]
        ),
        Document(
            content="Doc1",
            id_hash_keys=["content"]
        )
    ]
    document_store.write_documents(duplicate_documents, duplicate_documents="skip")
    assert len(document_store.get_all_documents()) == 1
    with pytest.raises(Exception):
        document_store.write_documents(duplicate_documents, duplicate_documents="fail")
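

# The three duplicate policies exercised above and below ("skip", "overwrite",
# "fail") can be tried in isolation. A minimal sketch, assuming a bare
# InMemoryDocumentStore; the leading underscore keeps pytest from collecting it.
def _sketch_duplicate_document_policies():
    from haystack.document_stores import InMemoryDocumentStore

    store = InMemoryDocumentStore()
    docs = [Document(content="Doc1", id_hash_keys=["content"])]

    store.write_documents(docs, duplicate_documents="skip")       # silently drops known ids
    store.write_documents(docs, duplicate_documents="overwrite")  # replaces the stored copy
    assert store.get_document_count() == 1

    with pytest.raises(DuplicateDocumentError):                   # rejects batches with known ids
        store.write_documents(docs, duplicate_documents="fail")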


@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory", "milvus", "weaviate"], indirect=True)
def test_write_with_duplicate_doc_ids_custom_index(document_store):
    duplicate_documents = [
        Document(
            content="Doc1",
            id_hash_keys=["content"]
        ),
        Document(
            content="Doc1",
            id_hash_keys=["content"]
        )
    ]
    document_store.delete_documents(index="haystack_custom_test")
    document_store.write_documents(duplicate_documents, index="haystack_custom_test", duplicate_documents="skip")
    assert len(document_store.get_all_documents(index="haystack_custom_test")) == 1
    with pytest.raises(DuplicateDocumentError):
        document_store.write_documents(duplicate_documents, index="haystack_custom_test", duplicate_documents="fail")

    # Weaviate manipulates document objects in-place when writing them to an index:
    # it generates a uuid based on the provided id and the name of the index the document is added to.
    # We need to get rid of these generated uuids for this test and therefore reset the document objects.
    # As a result, the documents will receive a fresh uuid based on their id_hash_keys and the new index name.
    if isinstance(document_store, WeaviateDocumentStore):
        duplicate_documents = [
            Document(
                content="Doc1",
                id_hash_keys=["content"]
            ),
            Document(
                content="Doc1",
                id_hash_keys=["content"]
            )
        ]
    # writing to the default, empty index should still work
    document_store.write_documents(duplicate_documents, duplicate_documents="fail")
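

# Why the reset above is needed: Weaviate stamps a deterministic, index-scoped uuid
# onto each Document object it writes, so the same object cannot simply be re-used
# for a second index. A rough sketch of the idea, assuming a uuid5 over index name
# plus document id (the exact scheme is an internal detail of the Weaviate
# integration, not a public API):
#
#     import uuid
#     doc_uuid = uuid.uuid5(uuid.NAMESPACE_DNS, f"{index_name}-{doc.id}")
#
# Re-creating the Document objects discards the stamped uuid, so the follow-up write
# to the default index can derive a fresh one from the id_hash_keys.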


def test_get_all_documents_without_filters(document_store_with_docs):
    documents = document_store_with_docs.get_all_documents()
    assert all(isinstance(d, Document) for d in documents)
    assert len(documents) == 3
    assert {d.meta["name"] for d in documents} == {"filename1", "filename2", "filename3"}
    assert {d.meta["meta_field"] for d in documents} == {"test1", "test2", "test3"}


def test_get_all_document_filter_duplicate_text_value(document_store):
    documents = [
        Document(
            content="Doc1",
            meta={"f1": "0"},
            id_hash_keys=["meta"]
        ),
        Document(
            content="Doc1",
            meta={"f1": "1", "meta_id": "0"},
            id_hash_keys=["meta"]
        ),
        Document(
            content="Doc2",
            meta={"f3": "0"},
            id_hash_keys=["meta"]
        )
    ]
    document_store.write_documents(documents)
    documents = document_store.get_all_documents(filters={"f1": ["1"]})
    assert documents[0].content == "Doc1"
    assert len(documents) == 1
    assert {d.meta["meta_id"] for d in documents} == {"0"}

    documents = document_store.get_all_documents(filters={"f1": ["0"]})
    assert documents[0].content == "Doc1"
    assert len(documents) == 1
    assert documents[0].meta.get("meta_id") is None

    documents = document_store.get_all_documents(filters={"f3": ["0"]})
    assert documents[0].content == "Doc2"
    assert len(documents) == 1
    assert documents[0].meta.get("meta_id") is None


def test_get_all_documents_with_correct_filters(document_store_with_docs):
    documents = document_store_with_docs.get_all_documents(filters={"meta_field": ["test2"]})
    assert len(documents) == 1
    assert documents[0].meta["name"] == "filename2"

    documents = document_store_with_docs.get_all_documents(filters={"meta_field": ["test1", "test3"]})
    assert len(documents) == 2
    assert {d.meta["name"] for d in documents} == {"filename1", "filename3"}
    assert {d.meta["meta_field"] for d in documents} == {"test1", "test3"}


def test_get_all_documents_with_correct_filters_legacy_sqlite(test_docs_xs):
    document_store_with_docs = get_document_store("sql")
    document_store_with_docs.write_documents(test_docs_xs)

    document_store_with_docs.use_windowed_query = False
    documents = document_store_with_docs.get_all_documents(filters={"meta_field": ["test2"]})
    assert len(documents) == 1
    assert documents[0].meta["name"] == "filename2"

    documents = document_store_with_docs.get_all_documents(filters={"meta_field": ["test1", "test3"]})
    assert len(documents) == 2
    assert {d.meta["name"] for d in documents} == {"filename1", "filename3"}
    assert {d.meta["meta_field"] for d in documents} == {"test1", "test3"}


def test_get_all_documents_with_incorrect_filter_name(document_store_with_docs):
    documents = document_store_with_docs.get_all_documents(filters={"incorrect_meta_field": ["test2"]})
    assert len(documents) == 0


def test_get_all_documents_with_incorrect_filter_value(document_store_with_docs):
    documents = document_store_with_docs.get_all_documents(filters={"meta_field": ["incorrect_value"]})
    assert len(documents) == 0


def test_get_documents_by_id(document_store_with_docs):
    documents = document_store_with_docs.get_all_documents()
    doc = document_store_with_docs.get_document_by_id(documents[0].id)
    assert doc.id == documents[0].id
    assert doc.content == documents[0].content


def test_get_document_count(document_store):
    documents = [
        {"content": "text1", "id": "1", "meta_field_for_count": "a"},
        {"content": "text2", "id": "2", "meta_field_for_count": "b"},
        {"content": "text3", "id": "3", "meta_field_for_count": "b"},
        {"content": "text4", "id": "4", "meta_field_for_count": "b"},
    ]
    document_store.write_documents(documents)
    assert document_store.get_document_count() == 4
    assert document_store.get_document_count(filters={"meta_field_for_count": ["a"]}) == 1
    assert document_store.get_document_count(filters={"meta_field_for_count": ["b"]}) == 3


def test_get_all_documents_generator(document_store):
    documents = [
        {"content": "text1", "id": "1", "meta_field_for_count": "a"},
        {"content": "text2", "id": "2", "meta_field_for_count": "b"},
        {"content": "text3", "id": "3", "meta_field_for_count": "b"},
        {"content": "text4", "id": "4", "meta_field_for_count": "b"},
        {"content": "text5", "id": "5", "meta_field_for_count": "b"},
    ]

    document_store.write_documents(documents)
    assert len(list(document_store.get_all_documents_generator(batch_size=2))) == 5
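

# The generator variant matters once an index no longer fits in memory: it fetches
# documents lazily in batch_size chunks instead of materializing one big list.
# A minimal consumption sketch over any store populated like the one above:
def _sketch_stream_documents(document_store):
    for doc in document_store.get_all_documents_generator(batch_size=2):
        # each item is a haystack.schema.Document; process it and let it go out of scope
        print(doc.id, doc.content[:30])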


@pytest.mark.parametrize("update_existing_documents", [True, False])
def test_update_existing_documents(document_store, update_existing_documents):
    original_docs = [
        {"content": "text1_orig", "id": "1", "meta_field_for_count": "a"},
    ]

    updated_docs = [
        {"content": "text1_new", "id": "1", "meta_field_for_count": "a"},
    ]

    document_store.write_documents(original_docs)
    assert document_store.get_document_count() == 1

    if update_existing_documents:
        document_store.write_documents(updated_docs, duplicate_documents="overwrite")
    else:
        with pytest.raises(Exception):
            document_store.write_documents(updated_docs, duplicate_documents="fail")

    stored_docs = document_store.get_all_documents()
    assert len(stored_docs) == 1
    if update_existing_documents:
        assert stored_docs[0].content == updated_docs[0]["content"]
    else:
        assert stored_docs[0].content == original_docs[0]["content"]


def test_write_document_meta(document_store):
    documents = [
        {"content": "dict_without_meta", "id": "1"},
        {"content": "dict_with_meta", "meta_field": "test2", "name": "filename2", "id": "2"},
        Document(content="document_object_without_meta", id="3"),
        Document(content="document_object_with_meta", meta={"meta_field": "test4", "name": "filename3"}, id="4"),
    ]
    document_store.write_documents(documents)
    documents_in_store = document_store.get_all_documents()
    assert len(documents_in_store) == 4

    assert not document_store.get_document_by_id("1").meta
    assert document_store.get_document_by_id("2").meta["meta_field"] == "test2"
    assert not document_store.get_document_by_id("3").meta
    assert document_store.get_document_by_id("4").meta["meta_field"] == "test4"


def test_write_document_index(document_store):
    documents = [
        {"content": "text1", "id": "1"},
        {"content": "text2", "id": "2"},
    ]
    document_store.write_documents([documents[0]], index="haystack_test_one")
    assert len(document_store.get_all_documents(index="haystack_test_one")) == 1

    document_store.write_documents([documents[1]], index="haystack_test_two")
    assert len(document_store.get_all_documents(index="haystack_test_two")) == 1

    assert len(document_store.get_all_documents(index="haystack_test_one")) == 1
    assert len(document_store.get_all_documents()) == 0


def test_document_with_embeddings(document_store):
    documents = [
        {"content": "text1", "id": "1", "embedding": np.random.rand(768).astype(np.float32)},
        {"content": "text2", "id": "2", "embedding": np.random.rand(768).astype(np.float64)},
        {"content": "text3", "id": "3", "embedding": np.random.rand(768).astype(np.float32).tolist()},
        {"content": "text4", "id": "4", "embedding": np.random.rand(768).astype(np.float32)},
    ]
    document_store.write_documents(documents, index="haystack_test_one")
    assert len(document_store.get_all_documents(index="haystack_test_one")) == 4

    if not isinstance(document_store, WeaviateDocumentStore):
        # Weaviate is excluded because it would return dummy vectors instead of None
        documents_without_embedding = document_store.get_all_documents(index="haystack_test_one", return_embedding=False)
        assert documents_without_embedding[0].embedding is None

    documents_with_embedding = document_store.get_all_documents(index="haystack_test_one", return_embedding=True)
    assert isinstance(documents_with_embedding[0].embedding, (list, np.ndarray))
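

# As the test above shows, embeddings may be supplied as float32 arrays, float64
# arrays, or plain Python lists, and the store normalizes them on write. A minimal
# round-trip sketch, assuming the default 768-dim embedding field:
def _sketch_embedding_round_trip(document_store):
    vector = np.random.rand(768).astype(np.float32)
    document_store.write_documents([{"content": "some text", "id": "rt", "embedding": vector}])
    stored = document_store.get_all_documents(return_embedding=True)
    # the comparison is approximate: a store may cast dtypes while persisting
    assert np.allclose(np.array(stored[0].embedding, dtype=np.float32), vector, atol=1e-6)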


@pytest.mark.parametrize("retriever", ["embedding"], indirect=True)
def test_update_embeddings(document_store, retriever):
    documents = []
    for i in range(6):
        documents.append({"content": f"text_{i}", "id": str(i), "meta_field": f"value_{i}"})
    documents.append({"content": "text_0", "id": "6", "meta_field": "value_0"})

    document_store.write_documents(documents, index="haystack_test_one")
    document_store.update_embeddings(retriever, index="haystack_test_one", batch_size=3)
    documents = document_store.get_all_documents(index="haystack_test_one", return_embedding=True)
    assert len(documents) == 7
    for doc in documents:
        assert type(doc.embedding) is np.ndarray

    documents = document_store.get_all_documents(
        index="haystack_test_one",
        filters={"meta_field": ["value_0"]},
        return_embedding=True,
    )
    assert len(documents) == 2
    for doc in documents:
        assert doc.meta["meta_field"] == "value_0"
    np.testing.assert_array_almost_equal(documents[0].embedding, documents[1].embedding, decimal=4)

    documents = document_store.get_all_documents(
        index="haystack_test_one",
        filters={"meta_field": ["value_0", "value_5"]},
        return_embedding=True,
    )
    documents_with_value_0 = [doc for doc in documents if doc.meta["meta_field"] == "value_0"]
    documents_with_value_5 = [doc for doc in documents if doc.meta["meta_field"] == "value_5"]
    np.testing.assert_raises(
        AssertionError,
        np.testing.assert_array_equal,
        documents_with_value_0[0].embedding,
        documents_with_value_5[0].embedding
    )

    doc = {"content": "text_7", "id": "7", "meta_field": "value_7",
           "embedding": retriever.embed_queries(texts=["a random string"])[0]}
    document_store.write_documents([doc], index="haystack_test_one")

    documents = []
    for i in range(8, 11):
        documents.append({"content": f"text_{i}", "id": str(i), "meta_field": f"value_{i}"})
    document_store.write_documents(documents, index="haystack_test_one")

    doc_before_update = document_store.get_all_documents(index="haystack_test_one", filters={"meta_field": ["value_7"]})[0]
    embedding_before_update = doc_before_update.embedding

    # test updating only documents without embeddings
    if not isinstance(document_store, WeaviateDocumentStore):
        # all documents in a Weaviate store have an embedding by default, so update_existing_embeddings=False is not allowed
        document_store.update_embeddings(retriever, index="haystack_test_one", batch_size=3, update_existing_embeddings=False)
        doc_after_update = document_store.get_all_documents(index="haystack_test_one", filters={"meta_field": ["value_7"]})[0]
        embedding_after_update = doc_after_update.embedding
        np.testing.assert_array_equal(embedding_before_update, embedding_after_update)

    # test updating with filters
    if isinstance(document_store, FAISSDocumentStore):
        with pytest.raises(Exception):
            document_store.update_embeddings(
                retriever, index="haystack_test_one", update_existing_embeddings=True, filters={"meta_field": ["value"]}
            )
    else:
        document_store.update_embeddings(
            retriever, index="haystack_test_one", batch_size=3, filters={"meta_field": ["value_0", "value_1"]}
        )
        doc_after_update = document_store.get_all_documents(index="haystack_test_one", filters={"meta_field": ["value_7"]})[0]
        embedding_after_update = doc_after_update.embedding
        np.testing.assert_array_equal(embedding_before_update, embedding_after_update)

    # test updating all embeddings
    document_store.update_embeddings(retriever, index="haystack_test_one", batch_size=3, update_existing_embeddings=True)
    assert document_store.get_embedding_count(index="haystack_test_one") == 11
    doc_after_update = document_store.get_all_documents(index="haystack_test_one", filters={"meta_field": ["value_7"]})[0]
    embedding_after_update = doc_after_update.embedding
    np.testing.assert_raises(AssertionError, np.testing.assert_array_equal, embedding_before_update, embedding_after_update)

    # test updating embeddings for newly added docs
    documents = []
    for i in range(12, 15):
        documents.append({"content": f"text_{i}", "id": str(i), "meta_field": f"value_{i}"})
    document_store.write_documents(documents, index="haystack_test_one")

    if not isinstance(document_store, WeaviateDocumentStore):
        # all documents in a Weaviate store have an embedding by default, so update_existing_embeddings=False is not allowed
        document_store.update_embeddings(retriever, index="haystack_test_one", batch_size=3, update_existing_embeddings=False)
        assert document_store.get_embedding_count(index="haystack_test_one") == 14
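

# update_embeddings() only needs an object that can turn documents into vectors; the
# stores call retriever.embed_documents(docs) and expect one vector per document.
# A minimal stub under that assumption, handy for exercising stores without loading
# a real model (not used by the tests above):
class _StubRetriever:
    def __init__(self, dim: int = 768):
        self.dim = dim

    def embed_documents(self, docs):
        import zlib

        # derive a stable per-content seed so identical texts get identical vectors
        seeds = [zlib.crc32(str(d.content).encode("utf-8")) for d in docs]
        return np.stack([np.random.RandomState(seed).rand(self.dim).astype(np.float32) for seed in seeds])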


@pytest.mark.parametrize("retriever", ["table_text_retriever"], indirect=True)
@pytest.mark.vector_dim(512)
def test_update_embeddings_table_text_retriever(document_store, retriever):
    documents = []
    for i in range(3):
        documents.append({"content": f"text_{i}",
                          "id": f"pssg_{i}",
                          "meta_field": f"value_text_{i}",
                          "content_type": "text"})
        documents.append({"content": pd.DataFrame(columns=[f"col_{i}", f"col_{i+1}"], data=[[f"cell_{i}", f"cell_{i+1}"]]),
                          "id": f"table_{i}",
                          "meta_field": f"value_table_{i}",
                          "content_type": "table"})
    documents.append({"content": "text_0",
                      "id": "pssg_4",
                      "meta_field": "value_text_0",
                      "content_type": "text"})
    documents.append({"content": pd.DataFrame(columns=["col_0", "col_1"], data=[["cell_0", "cell_1"]]),
                      "id": "table_4",
                      "meta_field": "value_table_0",
                      "content_type": "table"})

    document_store.write_documents(documents, index="haystack_test_one")
    document_store.update_embeddings(retriever, index="haystack_test_one", batch_size=3)
    documents = document_store.get_all_documents(index="haystack_test_one", return_embedding=True)
    assert len(documents) == 8
    for doc in documents:
        assert type(doc.embedding) is np.ndarray

    # Check that Documents with the same content (text) get the same embedding
    documents = document_store.get_all_documents(
        index="haystack_test_one",
        filters={"meta_field": ["value_text_0"]},
        return_embedding=True,
    )
    assert len(documents) == 2
    for doc in documents:
        assert doc.meta["meta_field"] == "value_text_0"
    np.testing.assert_array_almost_equal(documents[0].embedding, documents[1].embedding, decimal=4)

    # Check that Documents with the same content (table) get the same embedding
    documents = document_store.get_all_documents(
        index="haystack_test_one",
        filters={"meta_field": ["value_table_0"]},
        return_embedding=True,
    )
    assert len(documents) == 2
    for doc in documents:
        assert doc.meta["meta_field"] == "value_table_0"
    np.testing.assert_array_almost_equal(documents[0].embedding, documents[1].embedding, decimal=4)

    # Check that Documents with different content (text) get different embeddings
    documents = document_store.get_all_documents(
        index="haystack_test_one",
        filters={"meta_field": ["value_text_1", "value_text_2"]},
        return_embedding=True,
    )
    np.testing.assert_raises(
        AssertionError,
        np.testing.assert_array_equal,
        documents[0].embedding,
        documents[1].embedding
    )

    # Check that Documents with different content (table) get different embeddings
    documents = document_store.get_all_documents(
        index="haystack_test_one",
        filters={"meta_field": ["value_table_1", "value_table_2"]},
        return_embedding=True,
    )
    np.testing.assert_raises(
        AssertionError,
        np.testing.assert_array_equal,
        documents[0].embedding,
        documents[1].embedding
    )

    # Check that Documents with different content (table + text) get different embeddings
    documents = document_store.get_all_documents(
        index="haystack_test_one",
        filters={"meta_field": ["value_text_1", "value_table_1"]},
        return_embedding=True,
    )
    np.testing.assert_raises(
        AssertionError,
        np.testing.assert_array_equal,
        documents[0].embedding,
        documents[1].embedding
    )


def test_delete_all_documents(document_store_with_docs):
    assert len(document_store_with_docs.get_all_documents()) == 3

    document_store_with_docs.delete_documents()
    documents = document_store_with_docs.get_all_documents()
    assert len(documents) == 0


def test_delete_documents(document_store_with_docs):
    assert len(document_store_with_docs.get_all_documents()) == 3

    document_store_with_docs.delete_documents()
    documents = document_store_with_docs.get_all_documents()
    assert len(documents) == 0


def test_delete_documents_with_filters(document_store_with_docs):
    document_store_with_docs.delete_documents(filters={"meta_field": ["test1", "test2"]})
    documents = document_store_with_docs.get_all_documents()
    assert len(documents) == 1
    assert documents[0].meta["meta_field"] == "test3"


def test_delete_documents_by_id(document_store_with_docs):
    docs_to_delete = document_store_with_docs.get_all_documents(filters={"meta_field": ["test1", "test2"]})
    docs_not_to_delete = document_store_with_docs.get_all_documents(filters={"meta_field": ["test3"]})

    document_store_with_docs.delete_documents(ids=[doc.id for doc in docs_to_delete])
    all_docs_left = document_store_with_docs.get_all_documents()
    assert len(all_docs_left) == 1
    assert all_docs_left[0].meta["meta_field"] == "test3"

    all_ids_left = [doc.id for doc in all_docs_left]
    assert all(doc.id in all_ids_left for doc in docs_not_to_delete)


def test_delete_documents_by_id_with_filters(document_store_with_docs):
    docs_to_delete = document_store_with_docs.get_all_documents(filters={"meta_field": ["test1", "test2"]})
    docs_not_to_delete = document_store_with_docs.get_all_documents(filters={"meta_field": ["test3"]})

    # only docs matching both the given ids AND the filters should be deleted
    document_store_with_docs.delete_documents(ids=[doc.id for doc in docs_to_delete], filters={"meta_field": ["test1"]})

    all_docs_left = document_store_with_docs.get_all_documents()
    assert len(all_docs_left) == 2
    assert all(doc.meta["meta_field"] != "test1" for doc in all_docs_left)

    all_ids_left = [doc.id for doc in all_docs_left]
    assert all(doc.id in all_ids_left for doc in docs_not_to_delete)


# exclude Weaviate because it does not support storing labels
@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory", "milvus"], indirect=True)
def test_labels(document_store):
    label = Label(
        query="question1",
        answer=Answer(answer="answer",
                      type="extractive",
                      score=0.0,
                      context="something",
                      offsets_in_document=[Span(start=12, end=14)],
                      offsets_in_context=[Span(start=12, end=14)],
                      ),
        is_correct_answer=True,
        is_correct_document=True,
        document=Document(content="something", id="123"),
        no_answer=False,
        origin="gold-label",
    )
    document_store.write_labels([label], index="haystack_test_label")
    labels = document_store.get_all_labels(index="haystack_test_label")
    assert len(labels) == 1
    assert label == labels[0]

    # different index
    labels = document_store.get_all_labels()
    assert len(labels) == 0

    # write second label + duplicate
    label2 = Label(
        query="question2",
        answer=Answer(answer="another answer",
                      type="extractive",
                      score=0.0,
                      context="something",
                      offsets_in_document=[Span(start=12, end=14)],
                      offsets_in_context=[Span(start=12, end=14)],
                      ),
        is_correct_answer=True,
        is_correct_document=True,
        document=Document(content="something", id="324"),
        no_answer=False,
        origin="gold-label",
    )
    document_store.write_labels([label, label2], index="haystack_test_label")
    labels = document_store.get_all_labels(index="haystack_test_label")

    # check that the second label was added but not the duplicate
    assert len(labels) == 2
    assert label in labels
    assert label2 in labels

    # delete label2 by id
    document_store.delete_labels(index="haystack_test_label", ids=[labels[1].id])
    labels = document_store.get_all_labels(index="haystack_test_label")
    assert label == labels[0]
    assert len(labels) == 1

    # re-add label2
    document_store.write_labels([label2], index="haystack_test_label")
    labels = document_store.get_all_labels(index="haystack_test_label")
    assert len(labels) == 2

    # delete label2 by query text filter
    document_store.delete_labels(index="haystack_test_label", filters={"query": [labels[1].query]})
    labels = document_store.get_all_labels(index="haystack_test_label")
    assert label == labels[0]
    assert len(labels) == 1

    # re-add label2
    document_store.write_labels([label2], index="haystack_test_label")
    labels = document_store.get_all_labels(index="haystack_test_label")
    assert len(labels) == 2

    # delete the intersection of filters and ids, which is empty
    document_store.delete_labels(index="haystack_test_label", ids=[labels[0].id], filters={"query": [labels[1].query]})
    labels = document_store.get_all_labels(index="haystack_test_label")
    assert len(labels) == 2
    assert label in labels
    assert label2 in labels

    # delete all labels
    document_store.delete_labels(index="haystack_test_label")
    labels = document_store.get_all_labels(index="haystack_test_label")
    assert len(labels) == 0


# exclude Weaviate because it does not support storing labels
@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory", "milvus"], indirect=True)
def test_multilabel(document_store):
    labels = [
        Label(
            id="standard",
            query="question",
            answer=Answer(answer="answer1",
                          offsets_in_document=[Span(start=12, end=18)]),
            document=Document(content="some", id="123"),
            is_correct_answer=True,
            is_correct_document=True,
            no_answer=False,
            origin="gold-label",
        ),
        # different answer in same doc
        Label(
            id="diff-answer-same-doc",
            query="question",
            answer=Answer(answer="answer2",
                          offsets_in_document=[Span(start=12, end=18)]),
            document=Document(content="some", id="123"),
            is_correct_answer=True,
            is_correct_document=True,
            no_answer=False,
            origin="gold-label",
        ),
        # answer in different doc
        Label(
            id="diff-answer-diff-doc",
            query="question",
            answer=Answer(answer="answer3",
                          offsets_in_document=[Span(start=12, end=18)]),
            document=Document(content="some other", id="333"),
            is_correct_answer=True,
            is_correct_document=True,
            no_answer=False,
            origin="gold-label",
        ),
        # 'no answer', should be excluded from MultiLabel
        Label(
            id="4-no-answer",
            query="question",
            answer=Answer(answer="",
                          offsets_in_document=[Span(start=0, end=0)]),
            document=Document(content="some", id="777"),
            is_correct_answer=True,
            is_correct_document=True,
            no_answer=True,
            origin="gold-label",
        ),
        # is_correct_answer=False, should be excluded from MultiLabel if drop_negative_labels=True
        Label(
            id="5-negative",
            query="question",
            answer=Answer(answer="answer5",
                          offsets_in_document=[Span(start=12, end=18)]),
            document=Document(content="some", id="123"),
            is_correct_answer=False,
            is_correct_document=True,
            no_answer=False,
            origin="gold-label",
        ),
    ]
    document_store.write_labels(labels, index="haystack_test_multilabel")
    # regular labels - not aggregated
    list_labels = document_store.get_all_labels(index="haystack_test_multilabel")
    assert list_labels == labels
    assert len(list_labels) == 5

    # Currently we don't enforce writing (missing) docs automatically when adding labels, and there's no DB relationship between the two.
    # We should introduce this once we have refactored the logic of "index" to be rather a "collection" of labels + documents.
    # docs = document_store.get_all_documents(index="haystack_test_multilabel")
    # assert len(docs) == 3

    # Multi labels (open domain)
    multi_labels_open = document_store.get_all_labels_aggregated(index="haystack_test_multilabel",
                                                                 open_domain=True, drop_negative_labels=True)

    # for open-domain, all labels with the same question are grouped together
    assert len(multi_labels_open) == 1
    # all labels are in there except the negative one and the no_answer
    assert len(multi_labels_open[0].labels) == 4
    assert len(multi_labels_open[0].answers) == 3
    assert "5-negative" not in [l.id for l in multi_labels_open[0].labels]
    assert len(multi_labels_open[0].document_ids) == 3

    # Don't drop the negative label
    multi_labels_open = document_store.get_all_labels_aggregated(index="haystack_test_multilabel", open_domain=True,
                                                                 drop_no_answers=False, drop_negative_labels=False)
    assert len(multi_labels_open[0].labels) == 5
    assert len(multi_labels_open[0].answers) == 4
    assert len(multi_labels_open[0].document_ids) == 4

    # Drop no answer + negative
    multi_labels_open = document_store.get_all_labels_aggregated(index="haystack_test_multilabel", open_domain=True,
                                                                 drop_no_answers=True, drop_negative_labels=True)
    assert len(multi_labels_open[0].labels) == 3
    assert len(multi_labels_open[0].answers) == 3
    assert len(multi_labels_open[0].document_ids) == 3

    # for closed-domain, we group by document, so we expect 3 MultiLabels with 2, 1, and 1 labels each (negative dropped again)
    multi_labels = document_store.get_all_labels_aggregated(index="haystack_test_multilabel",
                                                            open_domain=False, drop_negative_labels=True)
    assert len(multi_labels) == 3
    label_counts = {len(ml.labels) for ml in multi_labels}
    assert label_counts == {2, 1}

    assert len(multi_labels[0].answers) == len(multi_labels[0].document_ids)

    # make sure there's nothing stored in another index
    multi_labels = document_store.get_all_labels_aggregated()
    assert len(multi_labels) == 0
    docs = document_store.get_all_documents()
    assert len(docs) == 0
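

# The aggregation rule in one line: open_domain=True groups labels by query string
# alone, while open_domain=False groups by query AND document id. A sketch of the
# open-domain half under that reading, using plain Label objects like the ones above:
def _sketch_open_domain_grouping(labels):
    from collections import defaultdict

    grouped = defaultdict(list)
    for label in labels:
        grouped[label.query].append(label)
    return grouped  # one entry per distinct question, mirroring multi_labels_open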


# exclude Weaviate because it does not support storing labels
@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "memory", "milvus"], indirect=True)
def test_multilabel_no_answer(document_store):
    labels = [
        Label(
            query="question",
            answer=Answer(answer=""),
            is_correct_answer=True,
            is_correct_document=True,
            document=Document(content="some", id="777"),
            no_answer=True,
            origin="gold-label",
        ),
        # no answer in different doc
        Label(
            query="question",
            answer=Answer(answer=""),
            is_correct_answer=True,
            is_correct_document=True,
            document=Document(content="some", id="123"),
            no_answer=True,
            origin="gold-label",
        ),
        # no answer in same doc, should be excluded
        Label(
            query="question",
            answer=Answer(answer=""),
            is_correct_answer=True,
            is_correct_document=True,
            document=Document(content="some", id="777"),
            no_answer=True,
            origin="gold-label",
        ),
        # no answer with is_correct_answer=False, should be excluded
        Label(
            query="question",
            answer=Answer(answer=""),
            is_correct_answer=False,
            is_correct_document=True,
            document=Document(content="some", id="777"),
            no_answer=True,
            origin="gold-label",
        ),
    ]

    document_store.write_labels(labels, index="haystack_test_multilabel_no_answer")

    labels = document_store.get_all_labels(index="haystack_test_multilabel_no_answer")
    assert len(labels) == 4

    multi_labels = document_store.get_all_labels_aggregated(index="haystack_test_multilabel_no_answer",
                                                            open_domain=True,
                                                            drop_no_answers=False,
                                                            drop_negative_labels=True)
    assert len(multi_labels) == 1
    assert multi_labels[0].no_answer is True
    assert len(multi_labels[0].document_ids) == 0
    assert len(multi_labels[0].answers) == 1

    multi_labels = document_store.get_all_labels_aggregated(index="haystack_test_multilabel_no_answer",
                                                            open_domain=True,
                                                            drop_no_answers=False,
                                                            drop_negative_labels=False)
    assert len(multi_labels) == 1
    assert multi_labels[0].no_answer is True
    assert len(multi_labels[0].document_ids) == 0
    assert len(multi_labels[0].labels) == 3
    assert len(multi_labels[0].answers) == 1


# Currently update_document_meta() is not implemented for the in-memory doc store
@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss"], indirect=True)
def test_update_meta(document_store):
    documents = [
        Document(
            content="Doc1",
            meta={"meta_key_1": "1", "meta_key_2": "1"}
        ),
        Document(
            content="Doc2",
            meta={"meta_key_1": "2", "meta_key_2": "2"}
        ),
        Document(
            content="Doc3",
            meta={"meta_key_1": "3", "meta_key_2": "3"}
        )
    ]
    document_store.write_documents(documents)
    document_2 = document_store.get_all_documents(filters={"meta_key_2": ["2"]})[0]
    document_store.update_document_meta(document_2.id, meta={"meta_key_1": "99", "meta_key_2": "2"})
    updated_document = document_store.get_document_by_id(document_2.id)
    assert len(updated_document.meta.keys()) == 2
    assert updated_document.meta["meta_key_1"] == "99"
    assert updated_document.meta["meta_key_2"] == "2"


@pytest.mark.parametrize("document_store_type", ["elasticsearch", "memory"])
def test_custom_embedding_field(document_store_type):
    document_store = get_document_store(
        document_store_type=document_store_type, embedding_field="custom_embedding_field"
    )
    doc_to_write = {"content": "test", "custom_embedding_field": np.random.rand(768).astype(np.float32)}
    document_store.write_documents([doc_to_write])
    documents = document_store.get_all_documents(return_embedding=True)
    assert len(documents) == 1
    assert documents[0].content == "test"
    np.testing.assert_array_equal(doc_to_write["custom_embedding_field"], documents[0].embedding)


@pytest.mark.parametrize("document_store", ["elasticsearch"], indirect=True)
def test_get_meta_values_by_key(document_store):
    documents = [
        Document(
            content="Doc1",
            meta={"meta_key_1": "1", "meta_key_2": "11"}
        ),
        Document(
            content="Doc2",
            meta={"meta_key_1": "2", "meta_key_2": "22"}
        ),
        Document(
            content="Doc3",
            meta={"meta_key_1": "3", "meta_key_2": "33"}
        )
    ]
    document_store.write_documents(documents)

    # test without filters or query
    result = document_store.get_metadata_values_by_key(key="meta_key_1")
    for bucket in result:
        assert bucket["value"] in ["1", "2", "3"]
        assert bucket["count"] == 1

    # test with filters but no query
    result = document_store.get_metadata_values_by_key(key="meta_key_1", filters={"meta_key_2": ["11", "22"]})
    for bucket in result:
        assert bucket["value"] in ["1", "2"]
        assert bucket["count"] == 1

    # test with a query but no filters
    result = document_store.get_metadata_values_by_key(key="meta_key_1", query="Doc1")
    for bucket in result:
        assert bucket["value"] in ["1"]
        assert bucket["count"] == 1
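

# The buckets asserted above follow the shape of an Elasticsearch terms aggregation:
# each entry is a dict like {"value": "1", "count": 1}. A small helper sketch that
# flattens them into a histogram, assuming the same key as in the test:
def _sketch_meta_value_histogram(document_store):
    buckets = document_store.get_metadata_values_by_key(key="meta_key_1")
    return {bucket["value"]: bucket["count"] for bucket in buckets}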


@pytest.mark.elasticsearch
def test_elasticsearch_custom_fields():
    client = Elasticsearch()
    client.indices.delete(index="haystack_test_custom", ignore=[404])
    document_store = ElasticsearchDocumentStore(index="haystack_test_custom", content_field="custom_text_field",
                                                embedding_field="custom_embedding_field")

    doc_to_write = {"custom_text_field": "test", "custom_embedding_field": np.random.rand(768).astype(np.float32)}
    document_store.write_documents([doc_to_write])
    documents = document_store.get_all_documents(return_embedding=True)
    assert len(documents) == 1
    assert documents[0].content == "test"
    np.testing.assert_array_equal(doc_to_write["custom_embedding_field"], documents[0].embedding)


@pytest.mark.elasticsearch
def test_get_document_count_only_documents_without_embedding_arg():
    documents = [
        {"content": "text1", "id": "1", "embedding": np.random.rand(768).astype(np.float32), "meta_field_for_count": "a"},
        {"content": "text2", "id": "2", "embedding": np.random.rand(768).astype(np.float64), "meta_field_for_count": "b"},
        {"content": "text3", "id": "3", "embedding": np.random.rand(768).astype(np.float32).tolist()},
        {"content": "text4", "id": "4", "meta_field_for_count": "b"},
        {"content": "text5", "id": "5", "meta_field_for_count": "b"},
        {"content": "text6", "id": "6", "meta_field_for_count": "c"},
        {"content": "text7", "id": "7", "embedding": np.random.rand(768).astype(np.float64), "meta_field_for_count": "c"},
    ]

    _index: str = "haystack_test_count"
    document_store = ElasticsearchDocumentStore(index=_index)
    document_store.delete_documents(index=_index)

    document_store.write_documents(documents)

    assert document_store.get_document_count() == 7
    assert document_store.get_document_count(only_documents_without_embedding=True) == 3
    assert document_store.get_document_count(only_documents_without_embedding=True,
                                             filters={"meta_field_for_count": ["c"]}) == 1
    assert document_store.get_document_count(only_documents_without_embedding=True,
                                             filters={"meta_field_for_count": ["b"]}) == 2


@pytest.mark.elasticsearch
def test_skip_missing_embeddings():
    documents = [
        {"content": "text1", "id": "1"},  # a document without embeddings
        {"content": "text2", "id": "2", "embedding": np.random.rand(768).astype(np.float64)},
        {"content": "text3", "id": "3", "embedding": np.random.rand(768).astype(np.float32).tolist()},
        {"content": "text4", "id": "4", "embedding": np.random.rand(768).astype(np.float32)}
    ]
    document_store = ElasticsearchDocumentStore(index="skip_missing_embedding_index")
    document_store.write_documents(documents)

    document_store.skip_missing_embeddings = True
    retrieved_docs = document_store.query_by_embedding(np.random.rand(768).astype(np.float32))
    assert len(retrieved_docs) == 3

    document_store.skip_missing_embeddings = False
    with pytest.raises(RequestError):
        document_store.query_by_embedding(np.random.rand(768).astype(np.float32))

    # Test the scenario where the entire index has no embeddings
    documents = [
        {"content": "text1", "id": "1"},
        {"content": "text2", "id": "2"},
        {"content": "text3", "id": "3"},
        {"content": "text4", "id": "4"}
    ]

    document_store.delete_documents()
    document_store.write_documents(documents)

    document_store.skip_missing_embeddings = True
    with pytest.raises(RequestError):
        document_store.query_by_embedding(np.random.rand(768).astype(np.float32))
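

# What the flag changes, as far as this test observes it: with
# skip_missing_embeddings=True the similarity query is expected to restrict itself
# to documents whose embedding field exists (e.g. via an Elasticsearch `exists`
# filter inside the script-score query), so documents without vectors are silently
# ignored; with the flag off, scoring a vector-less document makes Elasticsearch
# raise the RequestError asserted above. When no document in the index has an
# embedding at all, even the filtered query has nothing it can score, hence the
# final pytest.raises block.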


@pytest.mark.elasticsearch
def test_elasticsearch_synonyms():
    synonyms = ["i-pod, i pod, ipod", "sea biscuit, sea biscit, seabiscuit", "foo, foo bar, baz"]
    synonym_type = "synonym_graph"

    client = Elasticsearch()
    client.indices.delete(index="haystack_synonym_arg", ignore=[404])
    document_store = ElasticsearchDocumentStore(index="haystack_synonym_arg", synonyms=synonyms,
                                                synonym_type=synonym_type)
    indexed_settings = client.indices.get_settings(index="haystack_synonym_arg")

    assert synonym_type == indexed_settings["haystack_synonym_arg"]["settings"]["index"]["analysis"]["filter"]["synonym"]["type"]
    assert synonyms == indexed_settings["haystack_synonym_arg"]["settings"]["index"]["analysis"]["filter"]["synonym"]["synonyms"]


def test_custom_headers(document_store_with_docs: BaseDocumentStore):
    mock_client = None
    if isinstance(document_store_with_docs, ElasticsearchDocumentStore):
        es_document_store: ElasticsearchDocumentStore = document_store_with_docs
        mock_client = Mock(wraps=es_document_store.client)
        es_document_store.client = mock_client
    custom_headers = {"X-My-Custom-Header": "header-value"}
    if not mock_client:
        with pytest.raises(NotImplementedError):
            documents = document_store_with_docs.get_all_documents(headers=custom_headers)
    else:
        documents = document_store_with_docs.get_all_documents(headers=custom_headers)
        mock_client.search.assert_called_once()
        args, kwargs = mock_client.search.call_args
        assert "headers" in kwargs
        assert kwargs["headers"] == custom_headers
        assert len(documents) > 0
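

# Mock(wraps=...) is what makes the Elasticsearch branch above work: calls pass
# through to the wrapped object (so get_all_documents still returns real data)
# while the mock records them for later assertions. A self-contained sketch:
def _sketch_mock_wraps():
    real = {"a": 1}
    spy = Mock(wraps=real)
    assert spy.get("a") == 1              # delegated to the wrapped dict
    spy.get.assert_called_once_with("a")  # ...and recorded by the mock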