Enable bulk operations on vector IDs for FAISSDocumentStore (#460)

This commit is contained in:
Tanay Soni 2020-10-02 14:43:25 +02:00 committed by GitHub
parent 029d1b75f2
commit 669c72d538
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 46 additions and 18 deletions

View File

@@ -125,8 +125,8 @@ class FAISSDocumentStore(SQLDocumentStore):
doc.embedding = embeddings[i]
phi = self._get_phi(documents)
doc_meta_to_update = []
vector_id_map = {}
for i in range(0, len(documents), self.index_buffer_size):
vector_id = faiss_index.ntotal
embeddings = [doc.embedding for doc in documents[i: i + self.index_buffer_size]]
@@ -135,14 +135,10 @@ class FAISSDocumentStore(SQLDocumentStore):
faiss_index.add(hnsw_vectors)
for doc in documents[i: i + self.index_buffer_size]:
meta = doc.meta or {}
meta["vector_id"] = vector_id
vector_id_map[doc.id] = vector_id
vector_id += 1
doc_meta_to_update.append((doc.id, meta))
for doc_id, meta in doc_meta_to_update:
super(FAISSDocumentStore, self).update_document_meta(id=doc_id, meta=meta)
self.update_vector_ids(vector_id_map, index=index)
self.faiss_index = faiss_index
def query_by_embedding(
@@ -159,9 +155,7 @@ class FAISSDocumentStore(SQLDocumentStore):
score_matrix, vector_id_matrix = self.faiss_index.search(hnsw_vectors, top_k)
vector_ids_for_query = [str(vector_id) for vector_id in vector_id_matrix[0] if vector_id != -1]
documents = self.get_all_documents(filters={"vector_id": vector_ids_for_query}, index=index)
# sort the documents as per query results
documents = sorted(documents, key=lambda doc: vector_ids_for_query.index(doc.meta["vector_id"])) # type: ignore
documents = self.get_documents_by_vector_ids(vector_ids_for_query, index=index)
# assign query score to each document
scores_for_vector_ids: Dict[str, float] = {str(v_id): s for v_id, s in zip(vector_id_matrix[0], score_matrix[0])}

View File

@@ -4,6 +4,7 @@ from uuid import uuid4
from sqlalchemy import create_engine, Column, Integer, String, DateTime, func, ForeignKey, Boolean
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship, sessionmaker
from sqlalchemy.sql import case
from haystack.document_store.base import BaseDocumentStore
from haystack import Document, Label
@@ -25,6 +26,7 @@ class DocumentORM(ORMBase):
text = Column(String, nullable=False)
index = Column(String, nullable=False)
vector_id = Column(String, unique=True, nullable=True)
meta = relationship("MetaORM", backref="Document")
@@ -75,6 +77,16 @@ class SQLDocumentStore(BaseDocumentStore):
return documents
def get_documents_by_vector_ids(self, vector_ids: List[str], index: Optional[str] = None):
    """
    Fetch documents whose ``vector_id`` is in *vector_ids*, preserving the
    order of *vector_ids* in the returned list.

    :param vector_ids: vector_id strings to look up.
    :param index: optional index name to filter by; defaults to ``self.index``.
    :return: list of Document objects, ordered as in *vector_ids*.
    """
    index = index or self.index
    results = self.session.query(DocumentORM).filter(
        DocumentORM.vector_id.in_(vector_ids),
        DocumentORM.index == index
    ).all()
    # Build a vector_id -> position map once: O(n + m) overall instead of the
    # O(n * m) incurred by calling vector_ids.index() inside the sort key.
    position = {vector_id: i for i, vector_id in enumerate(vector_ids)}
    sorted_results = sorted(results, key=lambda doc: position[doc.vector_id])  # type: ignore
    documents = [self._convert_sql_row_to_document(row) for row in sorted_results]
    return documents
def get_all_documents(
self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None
) -> List[Document]:
@@ -116,8 +128,9 @@ class SQLDocumentStore(BaseDocumentStore):
index = index or self.index
for doc in document_objects:
meta_fields = doc.meta or {}
vector_id = meta_fields.get("vector_id")
meta_orms = [MetaORM(name=key, value=value) for key, value in meta_fields.items()]
doc_orm = DocumentORM(id=doc.id, text=doc.text, meta=meta_orms, index=index)
doc_orm = DocumentORM(id=doc.id, text=doc.text, vector_id=vector_id, meta=meta_orms, index=index)
self.session.add(doc_orm)
self.session.commit()
@@ -141,6 +154,25 @@ class SQLDocumentStore(BaseDocumentStore):
self.session.add(label_orm)
self.session.commit()
def update_vector_ids(self, vector_id_map: Dict[str, str], index: Optional[str] = None):
    """
    Bulk-assign vector_ids to existing documents in a single UPDATE statement.

    :param vector_id_map: mapping of document_id -> vector_id to write.
    :param index: optional index name to filter the affected documents by;
                  defaults to ``self.index``.
    """
    target_index = index or self.index
    # Restrict the UPDATE to the documents named in the mapping (and index).
    matching_docs = self.session.query(DocumentORM).filter(
        DocumentORM.index == target_index,
        DocumentORM.id.in_(vector_id_map),
    )
    # A single CASE expression keyed on DocumentORM.id assigns each row its
    # own vector_id, avoiding one UPDATE per document.
    matching_docs.update(
        {DocumentORM.vector_id: case(vector_id_map, value=DocumentORM.id)},
        synchronize_session=False,
    )
    self.session.commit()
def update_document_meta(self, id: str, meta: Dict[str, str]):
self.session.query(MetaORM).filter_by(document_id=id).delete()
meta_orms = [MetaORM(name=key, value=value, document_id=id) for key, value in meta.items()]
@@ -178,6 +210,8 @@ class SQLDocumentStore(BaseDocumentStore):
text=row.text,
meta={meta.name: meta.value for meta in row.meta}
)
if row.vector_id:
document.meta["vector_id"] = row.vector_id # type: ignore
return document
def _convert_sql_row_to_label(self, row) -> Label:

View File

@@ -271,24 +271,24 @@ def test_elasticsearch_update_meta(document_store):
documents = [
Document(
text="Doc1",
meta={"vector_id": "1", "meta_key": "1"}
meta={"meta_key_1": "1", "meta_key_2": "1"}
),
Document(
text="Doc2",
meta={"vector_id": "2", "meta_key": "2"}
meta={"meta_key_1": "2", "meta_key_2": "2"}
),
Document(
text="Doc3",
meta={"vector_id": "3", "meta_key": "3"}
meta={"meta_key_1": "3", "meta_key_2": "3"}
)
]
document_store.write_documents(documents)
document_2 = document_store.get_all_documents(filters={"meta_key": ["2"]})[0]
document_store.update_document_meta(document_2.id, meta={"vector_id": "99", "meta_key": "2"})
document_2 = document_store.get_all_documents(filters={"meta_key_2": ["2"]})[0]
document_store.update_document_meta(document_2.id, meta={"meta_key_1": "99", "meta_key_2": "2"})
updated_document = document_store.get_document_by_id(document_2.id)
assert len(updated_document.meta.keys()) == 2
assert updated_document.meta["vector_id"] == "99"
assert updated_document.meta["meta_key"] == "2"
assert updated_document.meta["meta_key_1"] == "99"
assert updated_document.meta["meta_key_2"] == "2"
def test_elasticsearch_custom_fields(elasticsearch_fixture):