diff --git a/haystack/document_store/faiss.py b/haystack/document_store/faiss.py index 422f000eb..d1e4b5c20 100644 --- a/haystack/document_store/faiss.py +++ b/haystack/document_store/faiss.py @@ -125,8 +125,8 @@ class FAISSDocumentStore(SQLDocumentStore): doc.embedding = embeddings[i] phi = self._get_phi(documents) - doc_meta_to_update = [] + vector_id_map = {} for i in range(0, len(documents), self.index_buffer_size): vector_id = faiss_index.ntotal embeddings = [doc.embedding for doc in documents[i: i + self.index_buffer_size]] @@ -135,14 +135,10 @@ class FAISSDocumentStore(SQLDocumentStore): faiss_index.add(hnsw_vectors) for doc in documents[i: i + self.index_buffer_size]: - meta = doc.meta or {} - meta["vector_id"] = vector_id + vector_id_map[doc.id] = vector_id vector_id += 1 - doc_meta_to_update.append((doc.id, meta)) - - for doc_id, meta in doc_meta_to_update: - super(FAISSDocumentStore, self).update_document_meta(id=doc_id, meta=meta) + self.update_vector_ids(vector_id_map, index=index) self.faiss_index = faiss_index def query_by_embedding( @@ -159,9 +155,7 @@ class FAISSDocumentStore(SQLDocumentStore): score_matrix, vector_id_matrix = self.faiss_index.search(hnsw_vectors, top_k) vector_ids_for_query = [str(vector_id) for vector_id in vector_id_matrix[0] if vector_id != -1] - documents = self.get_all_documents(filters={"vector_id": vector_ids_for_query}, index=index) - # sort the documents as per query results - documents = sorted(documents, key=lambda doc: vector_ids_for_query.index(doc.meta["vector_id"])) # type: ignore + documents = self.get_documents_by_vector_ids(vector_ids_for_query, index=index) # assign query score to each document scores_for_vector_ids: Dict[str, float] = {str(v_id): s for v_id, s in zip(vector_id_matrix[0], score_matrix[0])} diff --git a/haystack/document_store/sql.py b/haystack/document_store/sql.py index a6d0c2afc..38566615c 100644 --- a/haystack/document_store/sql.py +++ b/haystack/document_store/sql.py @@ -4,6 +4,7 @@ from uuid import uuid4 from sqlalchemy import create_engine, Column, Integer, String, DateTime, func, ForeignKey, Boolean from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import relationship, sessionmaker +from sqlalchemy.sql import case from haystack.document_store.base import BaseDocumentStore from haystack import Document, Label @@ -25,6 +26,7 @@ class DocumentORM(ORMBase): text = Column(String, nullable=False) index = Column(String, nullable=False) + vector_id = Column(String, unique=True, nullable=True) meta = relationship("MetaORM", backref="Document") @@ -75,6 +77,16 @@ class SQLDocumentStore(BaseDocumentStore): return documents + def get_documents_by_vector_ids(self, vector_ids: List[str], index: Optional[str] = None): + index = index or self.index + results = self.session.query(DocumentORM).filter( + DocumentORM.vector_id.in_(vector_ids), + DocumentORM.index == index + ).all() + sorted_results = sorted(results, key=lambda doc: vector_ids.index(doc.vector_id)) # type: ignore + documents = [self._convert_sql_row_to_document(row) for row in sorted_results] + return documents + def get_all_documents( self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None ) -> List[Document]: @@ -116,8 +128,9 @@ class SQLDocumentStore(BaseDocumentStore): index = index or self.index for doc in document_objects: meta_fields = doc.meta or {} + vector_id = meta_fields.get("vector_id") meta_orms = [MetaORM(name=key, value=value) for key, value in meta_fields.items()] - doc_orm = DocumentORM(id=doc.id, text=doc.text, meta=meta_orms, index=index) + doc_orm = DocumentORM(id=doc.id, text=doc.text, vector_id=vector_id, meta=meta_orms, index=index) self.session.add(doc_orm) self.session.commit() @@ -141,6 +154,25 @@ class SQLDocumentStore(BaseDocumentStore): self.session.add(label_orm) self.session.commit() + def update_vector_ids(self, vector_id_map: Dict[str, str], index: Optional[str] = None): + """ + Update vector_ids for given document_ids. + + :param vector_id_map: dict containing mapping of document_id -> vector_id. + :param index: filter documents by the optional index attribute for documents in database. + """ + index = index or self.index + self.session.query(DocumentORM).filter( + DocumentORM.id.in_(vector_id_map), + DocumentORM.index == index + ).update({ + DocumentORM.vector_id: case( + vector_id_map, + value=DocumentORM.id, + ) + }, synchronize_session=False) + self.session.commit() + def update_document_meta(self, id: str, meta: Dict[str, str]): self.session.query(MetaORM).filter_by(document_id=id).delete() meta_orms = [MetaORM(name=key, value=value, document_id=id) for key, value in meta.items()] @@ -178,6 +210,8 @@ class SQLDocumentStore(BaseDocumentStore): text=row.text, meta={meta.name: meta.value for meta in row.meta} ) + if row.vector_id: + document.meta["vector_id"] = row.vector_id # type: ignore return document def _convert_sql_row_to_label(self, row) -> Label: diff --git a/test/test_db.py b/test/test_db.py index 4e1d611c4..049313156 100644 --- a/test/test_db.py +++ b/test/test_db.py @@ -271,24 +271,24 @@ def test_elasticsearch_update_meta(document_store): documents = [ Document( text="Doc1", - meta={"vector_id": "1", "meta_key": "1"} + meta={"meta_key_1": "1", "meta_key_2": "1"} ), Document( text="Doc2", - meta={"vector_id": "2", "meta_key": "2"} + meta={"meta_key_1": "2", "meta_key_2": "2"} ), Document( text="Doc3", - meta={"vector_id": "3", "meta_key": "3"} + meta={"meta_key_1": "3", "meta_key_2": "3"} ) ] document_store.write_documents(documents) - document_2 = document_store.get_all_documents(filters={"meta_key": ["2"]})[0] - document_store.update_document_meta(document_2.id, meta={"vector_id": "99", "meta_key": "2"}) + document_2 = document_store.get_all_documents(filters={"meta_key_2": ["2"]})[0] + document_store.update_document_meta(document_2.id, meta={"meta_key_1": "99", "meta_key_2": "2"}) updated_document = document_store.get_document_by_id(document_2.id) assert len(updated_document.meta.keys()) == 2 - assert updated_document.meta["vector_id"] == "99" - assert updated_document.meta["meta_key"] == "2" + assert updated_document.meta["meta_key_1"] == "99" + assert updated_document.meta["meta_key_2"] == "2" def test_elasticsearch_custom_fields(elasticsearch_fixture):