mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-13 16:43:44 +00:00
Enable bulk operations on vector IDs for FAISSDocumentStore (#460)
This commit is contained in:
parent
029d1b75f2
commit
669c72d538
@ -125,8 +125,8 @@ class FAISSDocumentStore(SQLDocumentStore):
|
|||||||
doc.embedding = embeddings[i]
|
doc.embedding = embeddings[i]
|
||||||
|
|
||||||
phi = self._get_phi(documents)
|
phi = self._get_phi(documents)
|
||||||
doc_meta_to_update = []
|
|
||||||
|
|
||||||
|
vector_id_map = {}
|
||||||
for i in range(0, len(documents), self.index_buffer_size):
|
for i in range(0, len(documents), self.index_buffer_size):
|
||||||
vector_id = faiss_index.ntotal
|
vector_id = faiss_index.ntotal
|
||||||
embeddings = [doc.embedding for doc in documents[i: i + self.index_buffer_size]]
|
embeddings = [doc.embedding for doc in documents[i: i + self.index_buffer_size]]
|
||||||
@ -135,14 +135,10 @@ class FAISSDocumentStore(SQLDocumentStore):
|
|||||||
faiss_index.add(hnsw_vectors)
|
faiss_index.add(hnsw_vectors)
|
||||||
|
|
||||||
for doc in documents[i: i + self.index_buffer_size]:
|
for doc in documents[i: i + self.index_buffer_size]:
|
||||||
meta = doc.meta or {}
|
vector_id_map[doc.id] = vector_id
|
||||||
meta["vector_id"] = vector_id
|
|
||||||
vector_id += 1
|
vector_id += 1
|
||||||
doc_meta_to_update.append((doc.id, meta))
|
|
||||||
|
|
||||||
for doc_id, meta in doc_meta_to_update:
|
|
||||||
super(FAISSDocumentStore, self).update_document_meta(id=doc_id, meta=meta)
|
|
||||||
|
|
||||||
|
self.update_vector_ids(vector_id_map, index=index)
|
||||||
self.faiss_index = faiss_index
|
self.faiss_index = faiss_index
|
||||||
|
|
||||||
def query_by_embedding(
|
def query_by_embedding(
|
||||||
@ -159,9 +155,7 @@ class FAISSDocumentStore(SQLDocumentStore):
|
|||||||
score_matrix, vector_id_matrix = self.faiss_index.search(hnsw_vectors, top_k)
|
score_matrix, vector_id_matrix = self.faiss_index.search(hnsw_vectors, top_k)
|
||||||
vector_ids_for_query = [str(vector_id) for vector_id in vector_id_matrix[0] if vector_id != -1]
|
vector_ids_for_query = [str(vector_id) for vector_id in vector_id_matrix[0] if vector_id != -1]
|
||||||
|
|
||||||
documents = self.get_all_documents(filters={"vector_id": vector_ids_for_query}, index=index)
|
documents = self.get_documents_by_vector_ids(vector_ids_for_query, index=index)
|
||||||
# sort the documents as per query results
|
|
||||||
documents = sorted(documents, key=lambda doc: vector_ids_for_query.index(doc.meta["vector_id"])) # type: ignore
|
|
||||||
|
|
||||||
# assign query score to each document
|
# assign query score to each document
|
||||||
scores_for_vector_ids: Dict[str, float] = {str(v_id): s for v_id, s in zip(vector_id_matrix[0], score_matrix[0])}
|
scores_for_vector_ids: Dict[str, float] = {str(v_id): s for v_id, s in zip(vector_id_matrix[0], score_matrix[0])}
|
||||||
|
|||||||
@ -4,6 +4,7 @@ from uuid import uuid4
|
|||||||
from sqlalchemy import create_engine, Column, Integer, String, DateTime, func, ForeignKey, Boolean
|
from sqlalchemy import create_engine, Column, Integer, String, DateTime, func, ForeignKey, Boolean
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
from sqlalchemy.orm import relationship, sessionmaker
|
from sqlalchemy.orm import relationship, sessionmaker
|
||||||
|
from sqlalchemy.sql import case
|
||||||
|
|
||||||
from haystack.document_store.base import BaseDocumentStore
|
from haystack.document_store.base import BaseDocumentStore
|
||||||
from haystack import Document, Label
|
from haystack import Document, Label
|
||||||
@ -25,6 +26,7 @@ class DocumentORM(ORMBase):
|
|||||||
|
|
||||||
text = Column(String, nullable=False)
|
text = Column(String, nullable=False)
|
||||||
index = Column(String, nullable=False)
|
index = Column(String, nullable=False)
|
||||||
|
vector_id = Column(String, unique=True, nullable=True)
|
||||||
|
|
||||||
meta = relationship("MetaORM", backref="Document")
|
meta = relationship("MetaORM", backref="Document")
|
||||||
|
|
||||||
@ -75,6 +77,16 @@ class SQLDocumentStore(BaseDocumentStore):
|
|||||||
|
|
||||||
return documents
|
return documents
|
||||||
|
|
||||||
|
def get_documents_by_vector_ids(self, vector_ids: List[str], index: Optional[str] = None):
|
||||||
|
index = index or self.index
|
||||||
|
results = self.session.query(DocumentORM).filter(
|
||||||
|
DocumentORM.vector_id.in_(vector_ids),
|
||||||
|
DocumentORM.index == index
|
||||||
|
).all()
|
||||||
|
sorted_results = sorted(results, key=lambda doc: vector_ids.index(doc.vector_id)) # type: ignore
|
||||||
|
documents = [self._convert_sql_row_to_document(row) for row in sorted_results]
|
||||||
|
return documents
|
||||||
|
|
||||||
def get_all_documents(
|
def get_all_documents(
|
||||||
self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None
|
self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None
|
||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
@ -116,8 +128,9 @@ class SQLDocumentStore(BaseDocumentStore):
|
|||||||
index = index or self.index
|
index = index or self.index
|
||||||
for doc in document_objects:
|
for doc in document_objects:
|
||||||
meta_fields = doc.meta or {}
|
meta_fields = doc.meta or {}
|
||||||
|
vector_id = meta_fields.get("vector_id")
|
||||||
meta_orms = [MetaORM(name=key, value=value) for key, value in meta_fields.items()]
|
meta_orms = [MetaORM(name=key, value=value) for key, value in meta_fields.items()]
|
||||||
doc_orm = DocumentORM(id=doc.id, text=doc.text, meta=meta_orms, index=index)
|
doc_orm = DocumentORM(id=doc.id, text=doc.text, vector_id=vector_id, meta=meta_orms, index=index)
|
||||||
self.session.add(doc_orm)
|
self.session.add(doc_orm)
|
||||||
self.session.commit()
|
self.session.commit()
|
||||||
|
|
||||||
@ -141,6 +154,25 @@ class SQLDocumentStore(BaseDocumentStore):
|
|||||||
self.session.add(label_orm)
|
self.session.add(label_orm)
|
||||||
self.session.commit()
|
self.session.commit()
|
||||||
|
|
||||||
|
def update_vector_ids(self, vector_id_map: Dict[str, str], index: Optional[str] = None):
|
||||||
|
"""
|
||||||
|
Update vector_ids for given document_ids.
|
||||||
|
|
||||||
|
:param vector_id_map: dict containing mapping of document_id -> vector_id.
|
||||||
|
:param index: filter documents by the optional index attribute for documents in database.
|
||||||
|
"""
|
||||||
|
index = index or self.index
|
||||||
|
self.session.query(DocumentORM).filter(
|
||||||
|
DocumentORM.id.in_(vector_id_map),
|
||||||
|
DocumentORM.index == index
|
||||||
|
).update({
|
||||||
|
DocumentORM.vector_id: case(
|
||||||
|
vector_id_map,
|
||||||
|
value=DocumentORM.id,
|
||||||
|
)
|
||||||
|
}, synchronize_session=False)
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
def update_document_meta(self, id: str, meta: Dict[str, str]):
|
def update_document_meta(self, id: str, meta: Dict[str, str]):
|
||||||
self.session.query(MetaORM).filter_by(document_id=id).delete()
|
self.session.query(MetaORM).filter_by(document_id=id).delete()
|
||||||
meta_orms = [MetaORM(name=key, value=value, document_id=id) for key, value in meta.items()]
|
meta_orms = [MetaORM(name=key, value=value, document_id=id) for key, value in meta.items()]
|
||||||
@ -178,6 +210,8 @@ class SQLDocumentStore(BaseDocumentStore):
|
|||||||
text=row.text,
|
text=row.text,
|
||||||
meta={meta.name: meta.value for meta in row.meta}
|
meta={meta.name: meta.value for meta in row.meta}
|
||||||
)
|
)
|
||||||
|
if row.vector_id:
|
||||||
|
document.meta["vector_id"] = row.vector_id # type: ignore
|
||||||
return document
|
return document
|
||||||
|
|
||||||
def _convert_sql_row_to_label(self, row) -> Label:
|
def _convert_sql_row_to_label(self, row) -> Label:
|
||||||
|
|||||||
@ -271,24 +271,24 @@ def test_elasticsearch_update_meta(document_store):
|
|||||||
documents = [
|
documents = [
|
||||||
Document(
|
Document(
|
||||||
text="Doc1",
|
text="Doc1",
|
||||||
meta={"vector_id": "1", "meta_key": "1"}
|
meta={"meta_key_1": "1", "meta_key_2": "1"}
|
||||||
),
|
),
|
||||||
Document(
|
Document(
|
||||||
text="Doc2",
|
text="Doc2",
|
||||||
meta={"vector_id": "2", "meta_key": "2"}
|
meta={"meta_key_1": "2", "meta_key_2": "2"}
|
||||||
),
|
),
|
||||||
Document(
|
Document(
|
||||||
text="Doc3",
|
text="Doc3",
|
||||||
meta={"vector_id": "3", "meta_key": "3"}
|
meta={"meta_key_1": "3", "meta_key_2": "3"}
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
document_store.write_documents(documents)
|
document_store.write_documents(documents)
|
||||||
document_2 = document_store.get_all_documents(filters={"meta_key": ["2"]})[0]
|
document_2 = document_store.get_all_documents(filters={"meta_key_2": ["2"]})[0]
|
||||||
document_store.update_document_meta(document_2.id, meta={"vector_id": "99", "meta_key": "2"})
|
document_store.update_document_meta(document_2.id, meta={"meta_key_1": "99", "meta_key_2": "2"})
|
||||||
updated_document = document_store.get_document_by_id(document_2.id)
|
updated_document = document_store.get_document_by_id(document_2.id)
|
||||||
assert len(updated_document.meta.keys()) == 2
|
assert len(updated_document.meta.keys()) == 2
|
||||||
assert updated_document.meta["vector_id"] == "99"
|
assert updated_document.meta["meta_key_1"] == "99"
|
||||||
assert updated_document.meta["meta_key"] == "2"
|
assert updated_document.meta["meta_key_2"] == "2"
|
||||||
|
|
||||||
|
|
||||||
def test_elasticsearch_custom_fields(elasticsearch_fixture):
|
def test_elasticsearch_custom_fields(elasticsearch_fixture):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user