# Mirror of https://github.com/deepset-ai/haystack.git
# (synced 2025-08-29 10:56:40 +00:00) — Python, 167 lines, 7.5 KiB
import pytest
|
|
import numpy as np
|
|
import pandas as pd
|
|
|
|
from haystack.nodes import EmbeddingRetriever, TableTextRetriever
|
|
|
|
from ..conftest import document_store
|
|
|
|
|
|
@pytest.mark.parametrize("name", ["elasticsearch", "faiss", "memory"])
def test_update_embeddings(name, tmp_path):
    """Exercise ``update_embeddings`` with an ``EmbeddingRetriever`` on several document stores.

    Covers: embedding every document, identical content yielding (almost) identical
    embeddings, ``update_existing_embeddings=False`` leaving a pre-computed embedding
    untouched, filtered updates (expected to raise on FAISS), a full re-embedding that
    replaces the pre-computed embedding, and embedding of documents added afterwards.
    """
    documents = [{"content": f"text_{i}", "id": str(i), "meta_field": f"value_{i}"} for i in range(6)]
    # Same content/meta as id "0": used to check that equal text yields equal embeddings.
    documents.append({"content": "text_0", "id": "6", "meta_field": "value_0"})

    with document_store(name, documents, tmp_path) as ds:
        retriever = EmbeddingRetriever(document_store=ds, embedding_model="deepset/sentence_bert", use_gpu=False)

        ds.update_embeddings(retriever, batch_size=3)
        documents = ds.get_all_documents(return_embedding=True)
        assert len(documents) == 7
        for doc in documents:
            assert isinstance(doc.embedding, np.ndarray)

        # Documents with identical content must get (almost) identical embeddings.
        documents = ds.get_all_documents(filters={"meta_field": ["value_0"]}, return_embedding=True)
        assert len(documents) == 2
        for doc in documents:
            assert doc.meta["meta_field"] == "value_0"
        np.testing.assert_array_almost_equal(documents[0].embedding, documents[1].embedding, decimal=4)

        # Documents with different content must get different embeddings.
        documents = ds.get_all_documents(filters={"meta_field": ["value_0", "value_5"]}, return_embedding=True)
        documents_with_value_0 = [doc for doc in documents if doc.meta["meta_field"] == "value_0"]
        documents_with_value_5 = [doc for doc in documents if doc.meta["meta_field"] == "value_5"]
        # Fail with a clear assertion instead of an IndexError if the filter returned nothing.
        assert documents_with_value_0 and documents_with_value_5
        np.testing.assert_raises(
            AssertionError,
            np.testing.assert_array_equal,
            documents_with_value_0[0].embedding,
            documents_with_value_5[0].embedding,
        )

        # Document 7 is written together with its own (pre-computed) embedding.
        doc = {
            "content": "text_7",
            "id": "7",
            "meta_field": "value_7",
            "embedding": retriever.embed_queries(queries=["a random string"])[0],
        }
        ds.write_documents([doc])

        # Documents 8-10 are written without embeddings.
        documents = [{"content": f"text_{i}", "id": str(i), "meta_field": f"value_{i}"} for i in range(8, 11)]
        ds.write_documents(documents)

        def _embedding_of_doc_7():
            """Fetch document 7's current embedding from the store."""
            # return_embedding=True is explicit so the comparisons below never silently
            # compare None with None when the store does not return embeddings by default.
            return ds.get_all_documents(filters={"meta_field": ["value_7"]}, return_embedding=True)[0].embedding

        embedding_before_update = _embedding_of_doc_7()
        assert embedding_before_update is not None

        # Only documents without an embedding are embedded; doc 7 must keep its own.
        ds.update_embeddings(retriever, batch_size=3, update_existing_embeddings=False)
        np.testing.assert_array_equal(embedding_before_update, _embedding_of_doc_7())

        # Updating with filters: this test expects FAISS to raise when filters are combined with
        # update_existing_embeddings=True; other stores must only touch the filtered documents.
        if name == "faiss":
            with pytest.raises(Exception):
                ds.update_embeddings(retriever, update_existing_embeddings=True, filters={"meta_field": ["value"]})
        else:
            ds.update_embeddings(retriever, batch_size=3, filters={"meta_field": ["value_0", "value_1"]})
            np.testing.assert_array_equal(embedding_before_update, _embedding_of_doc_7())

        # Re-embed everything: doc 7's pre-computed embedding must be replaced.
        ds.update_embeddings(retriever, batch_size=3, update_existing_embeddings=True)
        assert ds.get_embedding_count() == 11
        np.testing.assert_raises(
            AssertionError, np.testing.assert_array_equal, embedding_before_update, _embedding_of_doc_7()
        )

        # Documents added later are embedded by a non-overwriting update.
        documents = [{"content": f"text_{i}", "id": str(i), "meta_field": f"value_{i}"} for i in range(12, 15)]
        ds.write_documents(documents)
        ds.update_embeddings(retriever, batch_size=3, update_existing_embeddings=False)
        assert ds.get_embedding_count() == 14
|
|
|
|
|
|
def test_update_embeddings_table_text_retriever(tmp_path):
    """Exercise ``update_embeddings`` with a ``TableTextRetriever`` on mixed text/table documents.

    Three text passages and three tables are written, plus one duplicate of the first text
    and one duplicate of the first table (new ids, same content and meta). Checks:
    - every document receives an embedding;
    - duplicates share (almost) the same embedding;
    - documents with different content (text/text, table/table, text/table) get
      different embeddings.
    """
    documents = []
    for i in range(3):
        documents.append(
            {"content": f"text_{i}", "id": f"pssg_{i}", "meta_field": f"value_text_{i}", "content_type": "text"}
        )
        documents.append(
            {
                "content": pd.DataFrame(columns=[f"col_{i}", f"col_{i+1}"], data=[[f"cell_{i}", f"cell_{i+1}"]]),
                "id": f"table_{i}",
                "meta_field": f"value_table_{i}",
                "content_type": "table",
            }
        )
    # Duplicates of the first text passage and the first table (same content, new ids).
    documents.append({"content": "text_0", "id": "pssg_4", "meta_field": "value_text_0", "content_type": "text"})
    documents.append(
        {
            "content": pd.DataFrame(columns=["col_0", "col_1"], data=[["cell_0", "cell_1"]]),
            "id": "table_4",
            "meta_field": "value_table_0",
            "content_type": "table",
        }
    )

    with document_store("elasticsearch", documents, tmp_path, embedding_dim=512) as ds:
        retriever = TableTextRetriever(
            # Must be the opened store `ds`, not the imported `document_store` context-manager function.
            document_store=ds,
            query_embedding_model="deepset/bert-small-mm_retrieval-question_encoder",
            passage_embedding_model="deepset/bert-small-mm_retrieval-passage_encoder",
            table_embedding_model="deepset/bert-small-mm_retrieval-table_encoder",
            use_gpu=False,
        )
        ds.update_embeddings(retriever, batch_size=3)
        documents = ds.get_all_documents(return_embedding=True)
        assert len(documents) == 8
        for doc in documents:
            assert isinstance(doc.embedding, np.ndarray)

        # Documents with the same content (text, then table) must get (almost) the same embedding.
        for meta_value in ("value_text_0", "value_table_0"):
            documents = ds.get_all_documents(filters={"meta_field": [meta_value]}, return_embedding=True)
            assert len(documents) == 2
            for doc in documents:
                assert doc.meta["meta_field"] == meta_value
            np.testing.assert_array_almost_equal(documents[0].embedding, documents[1].embedding, decimal=4)

        # Documents with different content must get different embeddings:
        # text vs text, table vs table, and text vs table.
        for meta_values in (
            ["value_text_1", "value_text_2"],
            ["value_table_1", "value_table_2"],
            ["value_text_1", "value_table_1"],
        ):
            documents = ds.get_all_documents(filters={"meta_field": meta_values}, return_embedding=True)
            # Guard the indexing below: both documents must be present.
            assert len(documents) == 2
            np.testing.assert_raises(
                AssertionError, np.testing.assert_array_equal, documents[0].embedding, documents[1].embedding
            )
|