diff --git a/haystack/document_stores/faiss.py b/haystack/document_stores/faiss.py index 74f73870c..4ecb0b1de 100644 --- a/haystack/document_stores/faiss.py +++ b/haystack/document_stores/faiss.py @@ -342,7 +342,7 @@ class FAISSDocumentStore(SQLDocumentStore): return logger.info("Updating embeddings for %s docs...", document_count) - vector_id = sum(index.ntotal for index in self.faiss_indexes.values()) + vector_id = self.faiss_indexes[index].ntotal result = self._query( index=index, diff --git a/haystack/document_stores/sql.py b/haystack/document_stores/sql.py index 61e547e0e..4f88c842a 100644 --- a/haystack/document_stores/sql.py +++ b/haystack/document_stores/sql.py @@ -19,6 +19,7 @@ try: text, JSON, ForeignKeyConstraint, + UniqueConstraint, ) from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import relationship, sessionmaker, validates @@ -52,10 +53,12 @@ class DocumentORM(ORMBase): content_type = Column(Text, nullable=True) # primary key in combination with id to allow the same doc in different indices index = Column(String(100), nullable=False, primary_key=True) - vector_id = Column(String(100), unique=True, nullable=True) + vector_id = Column(String(100), nullable=True) # speeds up queries for get_documents_by_vector_ids() by having a single query that returns joined metadata meta = relationship("MetaDocumentORM", back_populates="documents", lazy="joined") + __table_args__ = (UniqueConstraint("index", "vector_id", name="index_vector_id_uc"),) + class MetaDocumentORM(ORMBase): __tablename__ = "meta_document" diff --git a/test/conftest.py b/test/conftest.py index 303c00a5a..1ac5fe7e4 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -983,7 +983,7 @@ def setup_postgres(): with engine.connect() as connection: try: - connection.execute(text("DROP SCHEMA public CASCADE")) + connection.execute(text("DROP SCHEMA IF EXISTS public CASCADE")) except Exception as e: logging.error(e) connection.execute(text("CREATE SCHEMA public;")) diff --git a/test/document_stores/test_document_store.py b/test/document_stores/test_document_store.py index 5f039820e..8734f9dcd 100644 --- a/test/document_stores/test_document_store.py +++ b/test/document_stores/test_document_store.py @@ -489,7 +489,7 @@ def test_write_document_meta(document_store: BaseDocumentStore): @pytest.mark.parametrize("document_store", ["sql"], indirect=True) -def test_write_document_sql_invalid_meta(document_store: BaseDocumentStore): +def test_sql_write_document_invalid_meta(document_store: BaseDocumentStore): documents = [ { "content": "dict_with_invalid_meta", @@ -512,6 +512,23 @@ def test_write_document_sql_invalid_meta(document_store: BaseDocumentStore): assert document_store.get_document_by_id("2").meta == {"name": "filename2", "valid_meta_field": "test2"} +@pytest.mark.parametrize("document_store", ["sql"], indirect=True) +def test_sql_write_different_documents_same_vector_id(document_store: BaseDocumentStore): + doc1 = {"content": "content 1", "name": "doc1", "id": "1", "vector_id": "vector_id"} + doc2 = {"content": "content 2", "name": "doc2", "id": "2", "vector_id": "vector_id"} + + document_store.write_documents([doc1], index="index1") + documents_in_index1 = document_store.get_all_documents(index="index1") + assert len(documents_in_index1) == 1 + document_store.write_documents([doc2], index="index2") + documents_in_index2 = document_store.get_all_documents(index="index2") + assert len(documents_in_index2) == 1 + + document_store.write_documents([doc1], index="index3") + with pytest.raises(Exception, match=r"(?i)unique"): + document_store.write_documents([doc2], index="index3") + + def test_write_document_index(document_store: BaseDocumentStore): document_store.delete_index("haystack_test_one") document_store.delete_index("haystack_test_two") diff --git a/test/document_stores/test_faiss.py b/test/document_stores/test_faiss.py index 87302af3c..37df068bd 100644 --- a/test/document_stores/test_faiss.py +++ b/test/document_stores/test_faiss.py @@ -196,6 +196,39 @@ def test_faiss_write_docs(document_store, index_buffer_size, batch_size): assert np.allclose(original_doc["embedding"] / np.linalg.norm(original_doc["embedding"]), stored_emb, rtol=0.01) +@pytest.mark.parametrize("document_store", ["faiss"], indirect=True) +def test_faiss_write_docs_different_indexes(document_store): + document_store.write_documents(DOCUMENTS, index="index1") + document_store.write_documents(DOCUMENTS, index="index2") + + docs_from_index1 = document_store.get_all_documents(index="index1", return_embedding=False) + assert len(docs_from_index1) == len(DOCUMENTS) + assert {int(doc.meta["vector_id"]) for doc in docs_from_index1} == set(range(0, 6)) + + docs_from_index2 = document_store.get_all_documents(index="index2", return_embedding=False) + assert len(docs_from_index2) == len(DOCUMENTS) + assert {int(doc.meta["vector_id"]) for doc in docs_from_index2} == set(range(0, 6)) + + +@pytest.mark.parametrize("document_store", ["faiss"], indirect=True) +def test_faiss_update_docs_different_indexes(document_store): + retriever = MockDenseRetriever(document_store=document_store) + + document_store.write_documents(DOCUMENTS, index="index1") + document_store.write_documents(DOCUMENTS, index="index2") + + document_store.update_embeddings(retriever=retriever, update_existing_embeddings=True, index="index1") + document_store.update_embeddings(retriever=retriever, update_existing_embeddings=True, index="index2") + + docs_from_index1 = document_store.get_all_documents(index="index1", return_embedding=False) + assert len(docs_from_index1) == len(DOCUMENTS) + assert {int(doc.meta["vector_id"]) for doc in docs_from_index1} == set(range(0, 6)) + + docs_from_index2 = document_store.get_all_documents(index="index2", return_embedding=False) + assert len(docs_from_index2) == len(DOCUMENTS) + assert {int(doc.meta["vector_id"]) for doc in docs_from_index2} == set(range(0, 6)) + + @pytest.mark.skipif(sys.platform in ["win32", "cygwin"], reason="Test with tmp_path not working on windows runner") @pytest.mark.parametrize("index_factory", ["Flat", "HNSW", "IVF1,Flat"]) def test_faiss_retrieving(index_factory, tmp_path):