Add tests for custom embedding field (#640)

This commit is contained in:
Tanay Soni 2020-12-17 09:18:57 +01:00 committed by GitHub
parent 5b817387c2
commit 0e4eec9499
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 21 additions and 5 deletions

View File

@ -234,24 +234,25 @@ def document_store(request, test_docs_xs):
document_store.faiss_index.reset()
def get_document_store(document_store_type):
def get_document_store(document_store_type, embedding_field="embedding"):
if document_store_type == "sql":
if os.path.exists("haystack_test.db"):
os.remove("haystack_test.db")
document_store = SQLDocumentStore(url="sqlite:///haystack_test.db")
elif document_store_type == "memory":
document_store = InMemoryDocumentStore(return_embedding=True)
document_store = InMemoryDocumentStore(return_embedding=True, embedding_field=embedding_field)
elif document_store_type == "elasticsearch":
# make sure we start from a fresh index
client = Elasticsearch()
client.indices.delete(index='haystack_test*', ignore=[404])
document_store = ElasticsearchDocumentStore(index="haystack_test", return_embedding=True)
document_store = ElasticsearchDocumentStore(
index="haystack_test", return_embedding=True, embedding_field=embedding_field
)
elif document_store_type == "faiss":
if os.path.exists("haystack_test_faiss.db"):
os.remove("haystack_test_faiss.db")
document_store = FAISSDocumentStore(
sql_url="sqlite:///haystack_test_faiss.db",
return_embedding=True
sql_url="sqlite:///haystack_test_faiss.db", return_embedding=True, embedding_field=embedding_field
)
return document_store
else:

View File

@ -2,6 +2,7 @@ import numpy as np
import pytest
from elasticsearch import Elasticsearch
from conftest import get_document_store
from haystack import Document, Label
from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
@ -369,6 +370,20 @@ def test_elasticsearch_update_meta(document_store):
assert updated_document.meta["meta_key_2"] == "2"
@pytest.mark.elasticsearch
@pytest.mark.parametrize("document_store_type", ["elasticsearch", "memory"])
def test_custom_embedding_field(document_store_type):
    """Check that a vector stored under a non-default embedding field round-trips.

    A store of the parametrized type is created with ``embedding_field`` set to
    ``"custom_embedding_field"``. One document carrying a random float32 vector
    under that key is written, then every document is read back with
    ``return_embedding=True``. Exactly one document must come back, with the
    original text and with ``embedding`` equal to the vector that was written.
    """
    field_name = "custom_embedding_field"
    store = get_document_store(document_store_type=document_store_type, embedding_field=field_name)

    # The vector is supplied under the custom key instead of the default one.
    payload = {"text": "test", field_name: np.random.rand(768).astype(np.float32)}
    store.write_documents([payload])

    retrieved = store.get_all_documents(return_embedding=True)
    assert len(retrieved) == 1

    only_doc = retrieved[0]
    assert only_doc.text == "test"
    np.testing.assert_array_equal(payload[field_name], only_doc.embedding)
@pytest.mark.elasticsearch
def test_elasticsearch_custom_fields(elasticsearch_fixture):
client = Elasticsearch()