mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-01-05 11:38:20 +00:00
Add tests for custom embedding field (#640)
This commit is contained in:
parent
5b817387c2
commit
0e4eec9499
@ -234,24 +234,25 @@ def document_store(request, test_docs_xs):
|
||||
document_store.faiss_index.reset()
|
||||
|
||||
|
||||
def get_document_store(document_store_type):
|
||||
def get_document_store(document_store_type, embedding_field="embedding"):
|
||||
if document_store_type == "sql":
|
||||
if os.path.exists("haystack_test.db"):
|
||||
os.remove("haystack_test.db")
|
||||
document_store = SQLDocumentStore(url="sqlite:///haystack_test.db")
|
||||
elif document_store_type == "memory":
|
||||
document_store = InMemoryDocumentStore(return_embedding=True)
|
||||
document_store = InMemoryDocumentStore(return_embedding=True, embedding_field=embedding_field)
|
||||
elif document_store_type == "elasticsearch":
|
||||
# make sure we start from a fresh index
|
||||
client = Elasticsearch()
|
||||
client.indices.delete(index='haystack_test*', ignore=[404])
|
||||
document_store = ElasticsearchDocumentStore(index="haystack_test", return_embedding=True)
|
||||
document_store = ElasticsearchDocumentStore(
|
||||
index="haystack_test", return_embedding=True, embedding_field=embedding_field
|
||||
)
|
||||
elif document_store_type == "faiss":
|
||||
if os.path.exists("haystack_test_faiss.db"):
|
||||
os.remove("haystack_test_faiss.db")
|
||||
document_store = FAISSDocumentStore(
|
||||
sql_url="sqlite:///haystack_test_faiss.db",
|
||||
return_embedding=True
|
||||
sql_url="sqlite:///haystack_test_faiss.db", return_embedding=True, embedding_field=embedding_field
|
||||
)
|
||||
return document_store
|
||||
else:
|
||||
|
||||
@ -2,6 +2,7 @@ import numpy as np
|
||||
import pytest
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
from conftest import get_document_store
|
||||
from haystack import Document, Label
|
||||
from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
|
||||
|
||||
@ -369,6 +370,20 @@ def test_elasticsearch_update_meta(document_store):
|
||||
assert updated_document.meta["meta_key_2"] == "2"
|
||||
|
||||
|
||||
@pytest.mark.elasticsearch
|
||||
@pytest.mark.parametrize("document_store_type", ["elasticsearch", "memory"])
|
||||
def test_custom_embedding_field(document_store_type):
|
||||
document_store = get_document_store(
|
||||
document_store_type=document_store_type, embedding_field="custom_embedding_field"
|
||||
)
|
||||
doc_to_write = {"text": "test", "custom_embedding_field": np.random.rand(768).astype(np.float32)}
|
||||
document_store.write_documents([doc_to_write])
|
||||
documents = document_store.get_all_documents(return_embedding=True)
|
||||
assert len(documents) == 1
|
||||
assert documents[0].text == "test"
|
||||
np.testing.assert_array_equal(doc_to_write["custom_embedding_field"], documents[0].embedding)
|
||||
|
||||
|
||||
@pytest.mark.elasticsearch
|
||||
def test_elasticsearch_custom_fields(elasticsearch_fixture):
|
||||
client = Elasticsearch()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user