mirror of
				https://github.com/deepset-ai/haystack.git
				synced 2025-11-04 03:39:31 +00:00 
			
		
		
		
	* add e2e tests * move tests to their own module * add e2e workflow * pylint * remove from job * fix index field name * skip test on sql * removed unused code * fix embedding tests * adjust test for pinecone * adjust assertions to the new documents * bad copypasta * test * fix tests * fix tests * fix test * fix tests * pylint * update milvus version * remove debug * move graphdb tests under e2e
		
			
				
	
	
		
			167 lines
		
	
	
		
			7.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			167 lines
		
	
	
		
			7.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import pytest
 | 
						|
import numpy as np
 | 
						|
import pandas as pd
 | 
						|
 | 
						|
from haystack.nodes import EmbeddingRetriever, TableTextRetriever
 | 
						|
 | 
						|
from .conftest import document_store
 | 
						|
 | 
						|
 | 
						|
@pytest.mark.parametrize("name", ["elasticsearch", "faiss", "memory", "milvus"])
def test_update_embeddings(name, tmp_path):
    """
    Exercise `update_embeddings` on the document store selected by `name`.

    The scenario walks through, in order:
    embedding every document of a fresh store, documents with identical content getting
    (almost) identical embeddings, `update_existing_embeddings=False` leaving an already
    stored embedding untouched, an update restricted by filters (expected to raise for
    "faiss"), a full refresh of all embeddings, and embedding only newly written documents.
    """

    def make_docs(indices):
        # One plain dict per index; content, id and meta_field are all derived from the index.
        return [{"content": f"text_{i}", "id": str(i), "meta_field": f"value_{i}"} for i in indices]

    def embedding_of_doc_7(store):
        # Embedding currently held by the first document matching meta_field "value_7".
        return store.get_all_documents(filters={"meta_field": ["value_7"]})[0].embedding

    initial_docs = make_docs(range(6))
    # A seventh document repeating the content and meta_field of document "0".
    initial_docs.append({"content": "text_0", "id": "6", "meta_field": "value_0"})

    with document_store(name, initial_docs, tmp_path) as ds:
        retriever = EmbeddingRetriever(document_store=ds, embedding_model="deepset/sentence_bert", use_gpu=False)

        # Initial pass: every stored document must end up with a numpy embedding.
        ds.update_embeddings(retriever, batch_size=3)
        stored = ds.get_all_documents(return_embedding=True)
        assert len(stored) == 7
        for doc in stored:
            assert type(doc.embedding) is np.ndarray

        # The two "value_0" documents share their content, so their embeddings should match.
        same_content = ds.get_all_documents(filters={"meta_field": ["value_0"]}, return_embedding=True)
        assert len(same_content) == 2
        for doc in same_content:
            assert doc.meta["meta_field"] == "value_0"
        np.testing.assert_array_almost_equal(same_content[0].embedding, same_content[1].embedding, decimal=4)

        # Documents with different content must not have exactly equal embeddings.
        mixed = ds.get_all_documents(filters={"meta_field": ["value_0", "value_5"]}, return_embedding=True)
        value_0_docs = [doc for doc in mixed if doc.meta["meta_field"] == "value_0"]
        value_5_docs = [doc for doc in mixed if doc.meta["meta_field"] == "value_5"]
        np.testing.assert_raises(
            AssertionError,
            np.testing.assert_array_equal,
            value_0_docs[0].embedding,
            value_5_docs[0].embedding,
        )

        # Document "7" is written with an embedding of an unrelated query string,
        # so a later re-embedding of its own content is detectable.
        unrelated_embedding = retriever.embed_queries(queries=["a random string"])[0]
        ds.write_documents(
            [{"content": "text_7", "id": "7", "meta_field": "value_7", "embedding": unrelated_embedding}]
        )
        ds.write_documents(make_docs(range(8, 11)))

        embedding_before_update = embedding_of_doc_7(ds)

        # Only documents without an embedding are processed: document "7" must stay unchanged.
        ds.update_embeddings(retriever, batch_size=3, update_existing_embeddings=False)
        np.testing.assert_array_equal(embedding_before_update, embedding_of_doc_7(ds))

        # test updating with filters
        if name != "faiss":
            # The filter does not select document "7", so its embedding must stay unchanged.
            ds.update_embeddings(retriever, batch_size=3, filters={"meta_field": ["value_0", "value_1"]})
            np.testing.assert_array_equal(embedding_before_update, embedding_of_doc_7(ds))
        else:
            # For "faiss" the filtered update is expected to raise.
            with pytest.raises(Exception):
                ds.update_embeddings(retriever, update_existing_embeddings=True, filters={"meta_field": ["value"]})

        # test update all embeddings
        ds.update_embeddings(retriever, batch_size=3, update_existing_embeddings=True)
        assert ds.get_embedding_count() == 11
        # Document "7" has now been re-embedded from its own content.
        np.testing.assert_raises(
            AssertionError, np.testing.assert_array_equal, embedding_before_update, embedding_of_doc_7(ds)
        )

        # test update embeddings for newly added docs
        ds.write_documents(make_docs(range(12, 15)))

        ds.update_embeddings(retriever, batch_size=3, update_existing_embeddings=False)
        assert ds.get_embedding_count() == 14
 | 
						|
 | 
						|
 | 
						|
def test_update_embeddings_table_text_retriever(tmp_path):
    """
    Check `update_embeddings` with a `TableTextRetriever` on an Elasticsearch document store.

    Eight documents are written (four text passages and four tables). The last text passage and
    the last table duplicate the content and meta_field of the first ones. After embedding,
    the test verifies that duplicated content yields (almost) identical embeddings and that
    different content yields different embeddings.
    """
    documents = []
    for i in range(3):
        documents.append(
            {"content": f"text_{i}", "id": f"pssg_{i}", "meta_field": f"value_text_{i}", "content_type": "text"}
        )
        documents.append(
            {
                "content": pd.DataFrame(columns=[f"col_{i}", f"col_{i+1}"], data=[[f"cell_{i}", f"cell_{i+1}"]]),
                "id": f"table_{i}",
                "meta_field": f"value_table_{i}",
                "content_type": "table",
            }
        )
    # Duplicates of the first text passage and the first table, stored under new ids.
    documents.append({"content": "text_0", "id": "pssg_4", "meta_field": "value_text_0", "content_type": "text"})
    documents.append(
        {
            "content": pd.DataFrame(columns=["col_0", "col_1"], data=[["cell_0", "cell_1"]]),
            "id": "table_4",
            "meta_field": "value_table_0",
            "content_type": "table",
        }
    )

    with document_store("elasticsearch", documents, tmp_path, embedding_dim=512) as ds:
        retriever = TableTextRetriever(
            # Pass the store instance yielded by the context manager, not the
            # `document_store` factory imported from conftest.
            document_store=ds,
            query_embedding_model="deepset/bert-small-mm_retrieval-question_encoder",
            passage_embedding_model="deepset/bert-small-mm_retrieval-passage_encoder",
            table_embedding_model="deepset/bert-small-mm_retrieval-table_encoder",
            use_gpu=False,
        )
        ds.update_embeddings(retriever, batch_size=3)
        documents = ds.get_all_documents(return_embedding=True)
        assert len(documents) == 8
        for doc in documents:
            assert type(doc.embedding) is np.ndarray

        # Check if Documents with same content (text) get same embedding
        documents = ds.get_all_documents(filters={"meta_field": ["value_text_0"]}, return_embedding=True)
        assert len(documents) == 2
        for doc in documents:
            assert doc.meta["meta_field"] == "value_text_0"
        np.testing.assert_array_almost_equal(documents[0].embedding, documents[1].embedding, decimal=4)

        # Check if Documents with same content (table) get same embedding
        documents = ds.get_all_documents(filters={"meta_field": ["value_table_0"]}, return_embedding=True)
        assert len(documents) == 2
        for doc in documents:
            assert doc.meta["meta_field"] == "value_table_0"
        np.testing.assert_array_almost_equal(documents[0].embedding, documents[1].embedding, decimal=4)

        # Check if Documents with different content (text) get different embedding
        documents = ds.get_all_documents(
            filters={"meta_field": ["value_text_1", "value_text_2"]}, return_embedding=True
        )
        # Guard the [0]/[1] indexing below with a clear failure instead of an IndexError.
        assert len(documents) == 2
        np.testing.assert_raises(
            AssertionError, np.testing.assert_array_equal, documents[0].embedding, documents[1].embedding
        )

        # Check if Documents with different content (table) get different embeddings
        documents = ds.get_all_documents(
            filters={"meta_field": ["value_table_1", "value_table_2"]}, return_embedding=True
        )
        assert len(documents) == 2
        np.testing.assert_raises(
            AssertionError, np.testing.assert_array_equal, documents[0].embedding, documents[1].embedding
        )

        # Check if Documents with different content (table + text) get different embeddings
        documents = ds.get_all_documents(
            filters={"meta_field": ["value_text_1", "value_table_1"]}, return_embedding=True
        )
        assert len(documents) == 2
        np.testing.assert_raises(
            AssertionError, np.testing.assert_array_equal, documents[0].embedding, documents[1].embedding
        )
 |