fix: allow same vector_id in different indexes for SQL-based Document stores (#3383)

* fix_multiple_indexes

* improve test names
Stefano Fiorucci 2022-10-14 09:55:56 +02:00 committed by GitHub
parent ba30971d8d
commit 7290196c32
5 changed files with 57 additions and 4 deletions

View File

@@ -342,7 +342,7 @@ class FAISSDocumentStore(SQLDocumentStore):
             return
         logger.info("Updating embeddings for %s docs...", document_count)
-        vector_id = sum(index.ntotal for index in self.faiss_indexes.values())
+        vector_id = self.faiss_indexes[index].ntotal
         result = self._query(
             index=index,
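
This hunk is the core of the fix: the next vector_id used to be derived from the total vector count across all FAISS indexes (because vector_id had to be unique table-wide), and now starts from the target index's own ntotal. A minimal standalone sketch of the arithmetic, assuming faiss-cpu and numpy are installed; the faiss_indexes dict mirrors the attribute name in the diff, but this is not Haystack code:

import faiss
import numpy as np

dim = 4
faiss_indexes = {"index1": faiss.IndexFlatL2(dim), "index2": faiss.IndexFlatL2(dim)}
faiss_indexes["index1"].add(np.random.rand(3, dim).astype("float32"))

# Old behaviour: the next id was global across all indexes.
next_id_old = sum(index.ntotal for index in faiss_indexes.values())  # -> 3

# New behaviour: ids restart per index, matching the (index, vector_id)
# uniqueness introduced in the ORM change below.
next_id_new = faiss_indexes["index2"].ntotal  # -> 0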

View File

@@ -19,6 +19,7 @@ try:
         text,
         JSON,
         ForeignKeyConstraint,
+        UniqueConstraint,
     )
     from sqlalchemy.ext.declarative import declarative_base
     from sqlalchemy.orm import relationship, sessionmaker, validates
@@ -52,10 +53,12 @@ class DocumentORM(ORMBase):
     content_type = Column(Text, nullable=True)
     # primary key in combination with id to allow the same doc in different indices
     index = Column(String(100), nullable=False, primary_key=True)
-    vector_id = Column(String(100), unique=True, nullable=True)
+    vector_id = Column(String(100), nullable=True)
     # speeds up queries for get_documents_by_vector_ids() by having a single query that returns joined metadata
     meta = relationship("MetaDocumentORM", back_populates="documents", lazy="joined")
+
+    __table_args__ = (UniqueConstraint("index", "vector_id", name="index_vector_id_uc"),)


 class MetaDocumentORM(ORMBase):
     __tablename__ = "meta_document"
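
In SQLAlchemy terms, the hunk above swaps the column-level unique=True on vector_id for a table-level composite UniqueConstraint, so the same vector_id may now repeat across indexes but not within one. A self-contained sketch of the pattern, not Haystack's actual ORM; it assumes SQLAlchemy 1.4+ and an in-memory SQLite database:

from sqlalchemy import Column, String, UniqueConstraint, create_engine
from sqlalchemy.orm import declarative_base, sessionmaker

Base = declarative_base()

class Doc(Base):
    __tablename__ = "document"
    id = Column(String(100), primary_key=True)
    index = Column(String(100), primary_key=True)
    vector_id = Column(String(100), nullable=True)
    # uniqueness is enforced on the pair, not on vector_id alone
    __table_args__ = (UniqueConstraint("index", "vector_id", name="index_vector_id_uc"),)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with sessionmaker(bind=engine)() as session:
    # the same vector_id in two different indexes is now accepted
    session.add_all([Doc(id="1", index="index1", vector_id="0"),
                     Doc(id="2", index="index2", vector_id="0")])
    session.commit()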

View File

@@ -983,7 +983,7 @@ def setup_postgres():
     with engine.connect() as connection:
         try:
-            connection.execute(text("DROP SCHEMA public CASCADE"))
+            connection.execute(text("DROP SCHEMA IF EXISTS public CASCADE"))
         except Exception as e:
             logging.error(e)
         connection.execute(text("CREATE SCHEMA public;"))

View File

@@ -489,7 +489,7 @@ def test_write_document_meta(document_store: BaseDocumentStore):

 @pytest.mark.parametrize("document_store", ["sql"], indirect=True)
-def test_write_document_sql_invalid_meta(document_store: BaseDocumentStore):
+def test_sql_write_document_invalid_meta(document_store: BaseDocumentStore):
     documents = [
         {
             "content": "dict_with_invalid_meta",
@@ -512,6 +512,23 @@ def test_write_document_sql_invalid_meta(document_store: BaseDocumentStore):
     assert document_store.get_document_by_id("2").meta == {"name": "filename2", "valid_meta_field": "test2"}
+
+
+@pytest.mark.parametrize("document_store", ["sql"], indirect=True)
+def test_sql_write_different_documents_same_vector_id(document_store: BaseDocumentStore):
+    doc1 = {"content": "content 1", "name": "doc1", "id": "1", "vector_id": "vector_id"}
+    doc2 = {"content": "content 2", "name": "doc2", "id": "2", "vector_id": "vector_id"}
+    document_store.write_documents([doc1], index="index1")
+    documents_in_index1 = document_store.get_all_documents(index="index1")
+    assert len(documents_in_index1) == 1
+
+    document_store.write_documents([doc2], index="index2")
+    documents_in_index2 = document_store.get_all_documents(index="index2")
+    assert len(documents_in_index2) == 1
+
+    document_store.write_documents([doc1], index="index3")
+    with pytest.raises(Exception, match=r"(?i)unique"):
+        document_store.write_documents([doc2], index="index3")


 def test_write_document_index(document_store: BaseDocumentStore):
     document_store.delete_index("haystack_test_one")
     document_store.delete_index("haystack_test_two")

View File

@@ -196,6 +196,39 @@ def test_faiss_write_docs(document_store, index_buffer_size, batch_size):
     assert np.allclose(original_doc["embedding"] / np.linalg.norm(original_doc["embedding"]), stored_emb, rtol=0.01)
+
+
+@pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
+def test_faiss_write_docs_different_indexes(document_store):
+    document_store.write_documents(DOCUMENTS, index="index1")
+    document_store.write_documents(DOCUMENTS, index="index2")
+
+    docs_from_index1 = document_store.get_all_documents(index="index1", return_embedding=False)
+    assert len(docs_from_index1) == len(DOCUMENTS)
+    assert {int(doc.meta["vector_id"]) for doc in docs_from_index1} == set(range(0, 6))
+
+    docs_from_index2 = document_store.get_all_documents(index="index2", return_embedding=False)
+    assert len(docs_from_index2) == len(DOCUMENTS)
+    assert {int(doc.meta["vector_id"]) for doc in docs_from_index2} == set(range(0, 6))
+
+
+@pytest.mark.parametrize("document_store", ["faiss"], indirect=True)
+def test_faiss_update_docs_different_indexes(document_store):
+    retriever = MockDenseRetriever(document_store=document_store)
+
+    document_store.write_documents(DOCUMENTS, index="index1")
+    document_store.write_documents(DOCUMENTS, index="index2")
+
+    document_store.update_embeddings(retriever=retriever, update_existing_embeddings=True, index="index1")
+    document_store.update_embeddings(retriever=retriever, update_existing_embeddings=True, index="index2")
+
+    docs_from_index1 = document_store.get_all_documents(index="index1", return_embedding=False)
+    assert len(docs_from_index1) == len(DOCUMENTS)
+    assert {int(doc.meta["vector_id"]) for doc in docs_from_index1} == set(range(0, 6))
+
+    docs_from_index2 = document_store.get_all_documents(index="index2", return_embedding=False)
+    assert len(docs_from_index2) == len(DOCUMENTS)
+    assert {int(doc.meta["vector_id"]) for doc in docs_from_index2} == set(range(0, 6))


 @pytest.mark.skipif(sys.platform in ["win32", "cygwin"], reason="Test with tmp_path not working on windows runner")
 @pytest.mark.parametrize("index_factory", ["Flat", "HNSW", "IVF1,Flat"])
 def test_faiss_retrieving(index_factory, tmp_path):