mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-01 18:29:32 +00:00
Enable bulk operations on vector IDs for FAISSDocumentStore (#460)
This commit is contained in:
parent
029d1b75f2
commit
669c72d538
@ -125,8 +125,8 @@ class FAISSDocumentStore(SQLDocumentStore):
|
||||
doc.embedding = embeddings[i]
|
||||
|
||||
phi = self._get_phi(documents)
|
||||
doc_meta_to_update = []
|
||||
|
||||
vector_id_map = {}
|
||||
for i in range(0, len(documents), self.index_buffer_size):
|
||||
vector_id = faiss_index.ntotal
|
||||
embeddings = [doc.embedding for doc in documents[i: i + self.index_buffer_size]]
|
||||
@ -135,14 +135,10 @@ class FAISSDocumentStore(SQLDocumentStore):
|
||||
faiss_index.add(hnsw_vectors)
|
||||
|
||||
for doc in documents[i: i + self.index_buffer_size]:
|
||||
meta = doc.meta or {}
|
||||
meta["vector_id"] = vector_id
|
||||
vector_id_map[doc.id] = vector_id
|
||||
vector_id += 1
|
||||
doc_meta_to_update.append((doc.id, meta))
|
||||
|
||||
for doc_id, meta in doc_meta_to_update:
|
||||
super(FAISSDocumentStore, self).update_document_meta(id=doc_id, meta=meta)
|
||||
|
||||
self.update_vector_ids(vector_id_map, index=index)
|
||||
self.faiss_index = faiss_index
|
||||
|
||||
def query_by_embedding(
|
||||
@ -159,9 +155,7 @@ class FAISSDocumentStore(SQLDocumentStore):
|
||||
score_matrix, vector_id_matrix = self.faiss_index.search(hnsw_vectors, top_k)
|
||||
vector_ids_for_query = [str(vector_id) for vector_id in vector_id_matrix[0] if vector_id != -1]
|
||||
|
||||
documents = self.get_all_documents(filters={"vector_id": vector_ids_for_query}, index=index)
|
||||
# sort the documents as per query results
|
||||
documents = sorted(documents, key=lambda doc: vector_ids_for_query.index(doc.meta["vector_id"])) # type: ignore
|
||||
documents = self.get_documents_by_vector_ids(vector_ids_for_query, index=index)
|
||||
|
||||
# assign query score to each document
|
||||
scores_for_vector_ids: Dict[str, float] = {str(v_id): s for v_id, s in zip(vector_id_matrix[0], score_matrix[0])}
|
||||
|
||||
@ -4,6 +4,7 @@ from uuid import uuid4
|
||||
from sqlalchemy import create_engine, Column, Integer, String, DateTime, func, ForeignKey, Boolean
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import relationship, sessionmaker
|
||||
from sqlalchemy.sql import case
|
||||
|
||||
from haystack.document_store.base import BaseDocumentStore
|
||||
from haystack import Document, Label
|
||||
@ -25,6 +26,7 @@ class DocumentORM(ORMBase):
|
||||
|
||||
text = Column(String, nullable=False)
|
||||
index = Column(String, nullable=False)
|
||||
vector_id = Column(String, unique=True, nullable=True)
|
||||
|
||||
meta = relationship("MetaORM", backref="Document")
|
||||
|
||||
@ -75,6 +77,16 @@ class SQLDocumentStore(BaseDocumentStore):
|
||||
|
||||
return documents
|
||||
|
||||
def get_documents_by_vector_ids(self, vector_ids: List[str], index: Optional[str] = None) -> List[Document]:
    """
    Fetch documents whose ``vector_id`` matches one of the given ids, preserving
    the order of *vector_ids* (i.e. the retrieval ranking from the vector store).

    :param vector_ids: vector ids (as stored on DocumentORM.vector_id) to look up.
    :param index: optional index name to restrict the lookup; defaults to ``self.index``.
    :return: matching documents, sorted to follow the order of *vector_ids*.
    """
    index = index or self.index
    results = self.session.query(DocumentORM).filter(
        DocumentORM.vector_id.in_(vector_ids),
        DocumentORM.index == index
    ).all()
    # Precompute each id's rank once: O(1) lookup per row instead of the
    # O(len(vector_ids)) cost of list.index() inside the sort key.
    rank = {vector_id: position for position, vector_id in enumerate(vector_ids)}
    sorted_results = sorted(results, key=lambda doc: rank[doc.vector_id])  # type: ignore
    documents = [self._convert_sql_row_to_document(row) for row in sorted_results]
    return documents
|
||||
|
||||
def get_all_documents(
|
||||
self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None
|
||||
) -> List[Document]:
|
||||
@ -116,8 +128,9 @@ class SQLDocumentStore(BaseDocumentStore):
|
||||
index = index or self.index
|
||||
for doc in document_objects:
|
||||
meta_fields = doc.meta or {}
|
||||
vector_id = meta_fields.get("vector_id")
|
||||
meta_orms = [MetaORM(name=key, value=value) for key, value in meta_fields.items()]
|
||||
doc_orm = DocumentORM(id=doc.id, text=doc.text, meta=meta_orms, index=index)
|
||||
doc_orm = DocumentORM(id=doc.id, text=doc.text, vector_id=vector_id, meta=meta_orms, index=index)
|
||||
self.session.add(doc_orm)
|
||||
self.session.commit()
|
||||
|
||||
@ -141,6 +154,25 @@ class SQLDocumentStore(BaseDocumentStore):
|
||||
self.session.add(label_orm)
|
||||
self.session.commit()
|
||||
|
||||
def update_vector_ids(self, vector_id_map: Dict[str, str], index: Optional[str] = None):
    """
    Bulk-update the ``vector_id`` column for a set of documents.

    :param vector_id_map: dict containing mapping of document_id -> vector_id.
    :param index: filter documents by the optional index attribute for documents in database.
    """
    index = index or self.index
    # Restrict the UPDATE to the mapped document ids within the given index.
    matching_docs = self.session.query(DocumentORM).filter(
        DocumentORM.id.in_(vector_id_map),
        DocumentORM.index == index,
    )
    # A single CASE expression keyed on DocumentORM.id assigns every document
    # its new vector_id in one UPDATE statement, instead of one query per row.
    new_vector_id = case(vector_id_map, value=DocumentORM.id)
    matching_docs.update(
        {DocumentORM.vector_id: new_vector_id},
        synchronize_session=False,
    )
    self.session.commit()
|
||||
|
||||
def update_document_meta(self, id: str, meta: Dict[str, str]):
|
||||
self.session.query(MetaORM).filter_by(document_id=id).delete()
|
||||
meta_orms = [MetaORM(name=key, value=value, document_id=id) for key, value in meta.items()]
|
||||
@ -178,6 +210,8 @@ class SQLDocumentStore(BaseDocumentStore):
|
||||
text=row.text,
|
||||
meta={meta.name: meta.value for meta in row.meta}
|
||||
)
|
||||
if row.vector_id:
|
||||
document.meta["vector_id"] = row.vector_id # type: ignore
|
||||
return document
|
||||
|
||||
def _convert_sql_row_to_label(self, row) -> Label:
|
||||
|
||||
@ -271,24 +271,24 @@ def test_elasticsearch_update_meta(document_store):
|
||||
documents = [
|
||||
Document(
|
||||
text="Doc1",
|
||||
meta={"vector_id": "1", "meta_key": "1"}
|
||||
meta={"meta_key_1": "1", "meta_key_2": "1"}
|
||||
),
|
||||
Document(
|
||||
text="Doc2",
|
||||
meta={"vector_id": "2", "meta_key": "2"}
|
||||
meta={"meta_key_1": "2", "meta_key_2": "2"}
|
||||
),
|
||||
Document(
|
||||
text="Doc3",
|
||||
meta={"vector_id": "3", "meta_key": "3"}
|
||||
meta={"meta_key_1": "3", "meta_key_2": "3"}
|
||||
)
|
||||
]
|
||||
document_store.write_documents(documents)
|
||||
document_2 = document_store.get_all_documents(filters={"meta_key": ["2"]})[0]
|
||||
document_store.update_document_meta(document_2.id, meta={"vector_id": "99", "meta_key": "2"})
|
||||
document_2 = document_store.get_all_documents(filters={"meta_key_2": ["2"]})[0]
|
||||
document_store.update_document_meta(document_2.id, meta={"meta_key_1": "99", "meta_key_2": "2"})
|
||||
updated_document = document_store.get_document_by_id(document_2.id)
|
||||
assert len(updated_document.meta.keys()) == 2
|
||||
assert updated_document.meta["vector_id"] == "99"
|
||||
assert updated_document.meta["meta_key"] == "2"
|
||||
assert updated_document.meta["meta_key_1"] == "99"
|
||||
assert updated_document.meta["meta_key_2"] == "2"
|
||||
|
||||
|
||||
def test_elasticsearch_custom_fields(elasticsearch_fixture):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user