Add document update for SQL and FAISS Document Store (#584)

Author: Lalit Pagaria, 2020-11-16 16:08:13 +01:00 (committed by GitHub)
parent 3e095ddd7d
commit 3f81c93f36
3 changed files with 92 additions and 11 deletions

View File

@@ -37,6 +37,8 @@ class FAISSDocumentStore(SQLDocumentStore):
         faiss_index_factory_str: str = "Flat",
         faiss_index: Optional[faiss.swigfaiss.Index] = None,
         return_embedding: Optional[bool] = True,
+        update_existing_documents: bool = False,
+        index: str = "document",
         **kwargs,
     ):
         """
@@ -63,6 +65,11 @@ class FAISSDocumentStore(SQLDocumentStore):
         :param faiss_index: Pass an existing FAISS Index, i.e. an empty one that you configured manually
                             or one with docs that you used in Haystack before and want to load again.
         :param return_embedding: To return document embedding
+        :param update_existing_documents: Whether to update any existing documents with the same ID when adding
+                                          documents. When set to True, any document with an existing ID is updated.
+                                          If set to False, an error is raised if the ID of a document being added
+                                          already exists.
+        :param index: Name of the index in the document store to use.
         """
         self.vector_dim = vector_dim
@@ -73,7 +80,11 @@ class FAISSDocumentStore(SQLDocumentStore):
         self.index_buffer_size = index_buffer_size
         self.return_embedding = return_embedding
-        super().__init__(url=sql_url)
+        super().__init__(
+            url=sql_url,
+            update_existing_documents=update_existing_documents,
+            index=index
+        )

     def _create_new_index(self, vector_dim: int, index_factory: str = "Flat", metric_type=faiss.METRIC_INNER_PRODUCT, **kwargs):
         if index_factory == "HNSW" and metric_type == faiss.METRIC_INNER_PRODUCT:
@@ -99,12 +110,18 @@ class FAISSDocumentStore(SQLDocumentStore):
         # vector index
         if not self.faiss_index:
             raise ValueError("Couldn't find a FAISS index. Try to init the FAISSDocumentStore() again ...")

         # doc + metadata index
+        index = index or self.index
         document_objects = [Document.from_dict(d) if isinstance(d, dict) else d for d in documents]
         add_vectors = False if document_objects[0].embedding is None else True

+        if self.update_existing_documents and add_vectors:
+            logger.warning("You have enabled `update_existing_documents`, but `FAISSDocumentStore` does not "
+                           "support updating documents in an existing `faiss_index`.\n"
+                           "Please call the `update_embeddings` method to repopulate `faiss_index`.")

         for i in range(0, len(document_objects), self.index_buffer_size):
             vector_id = self.faiss_index.ntotal
             if add_vectors:
@@ -134,6 +151,9 @@ class FAISSDocumentStore(SQLDocumentStore):
         if not self.faiss_index:
             raise ValueError("Couldn't find a FAISS index. Try to init the FAISSDocumentStore() again ...")
+        # FAISS does not support updating data in an existing index, so clear all existing data first
+        self.faiss_index.reset()
+        index = index or self.index
         documents = self.get_all_documents(index=index)
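Taken together, the FAISS-side changes behave as in this minimal usage sketch (not part of the commit; the SQLite URL is illustrative, and the retriever mentioned in the final comment stands for any configured embedding retriever):

from haystack.document_store.faiss import FAISSDocumentStore

document_store = FAISSDocumentStore(
    sql_url="sqlite:///haystack_test.db",  # illustrative metadata database
    update_existing_documents=True,        # re-writing an existing ID now updates the SQL record
)
document_store.write_documents([{"text": "version 1", "id": "doc_1"}])
document_store.write_documents([{"text": "version 2", "id": "doc_1"}])  # updates instead of raising

# FAISS cannot update vectors in place, so after updating documents the embeddings
# must be rebuilt, e.g. document_store.update_embeddings(retriever)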

View File

@@ -1,3 +1,4 @@
+import logging
 from typing import Any, Dict, Union, List, Optional
 from uuid import uuid4
@@ -10,6 +11,10 @@ from haystack.document_store.base import BaseDocumentStore
 from haystack import Document, Label
 from haystack.preprocessor.utils import eval_data_from_file

+logger = logging.getLogger(__name__)

 Base = declarative_base()  # type: Any
@@ -37,7 +42,7 @@ class MetaORM(ORMBase):
     name = Column(String(100), index=True)
     value = Column(String(1000), index=True)
-    document_id = Column(String(100), ForeignKey("document.id", ondelete="CASCADE"), nullable=False)
+    document_id = Column(String(100), ForeignKey("document.id", ondelete="CASCADE", onupdate="CASCADE"), nullable=False)

     documents = relationship(DocumentORM, backref="Meta")
@@ -45,7 +50,7 @@ class MetaORM(ORMBase):
 class LabelORM(ORMBase):
     __tablename__ = "label"

-    document_id = Column(String(100), ForeignKey("document.id", ondelete="CASCADE"), nullable=False)
+    document_id = Column(String(100), ForeignKey("document.id", ondelete="CASCADE", onupdate="CASCADE"), nullable=False)
     index = Column(String(100), nullable=False)
     no_answer = Column(Boolean, nullable=False)
     origin = Column(String(100), nullable=False)
@@ -58,13 +63,30 @@ class LabelORM(ORMBase):
 class SQLDocumentStore(BaseDocumentStore):
-    def __init__(self, url: str = "sqlite://", index="document"):
+    def __init__(
+        self,
+        url: str = "sqlite://",
+        index: str = "document",
+        label_index: str = "label",
+        update_existing_documents: bool = False,
+    ):
+        """
+        :param url: URL for the SQL database as expected by SQLAlchemy. More info here: https://docs.sqlalchemy.org/en/13/core/engines.html#database-urls
+        :param index: The documents are scoped to an index attribute that can be used when writing, querying, or deleting documents.
+                      This parameter sets the default value for the document index.
+        :param label_index: The default value of the index attribute for the labels.
+        :param update_existing_documents: Whether to update any existing documents with the same ID when adding
+                                          documents. When set to True, any document with an existing ID is updated.
+                                          If set to False, an error is raised if the ID of a document being added
+                                          already exists. Using this parameter could degrade performance for
+                                          document insertion.
+        """
         engine = create_engine(url)
         ORMBase.metadata.create_all(engine)
         Session = sessionmaker(bind=engine)
         self.session = Session()
         self.index = index
-        self.label_index = "label"
+        self.label_index = label_index
+        self.update_existing_documents = update_existing_documents

     def get_document_by_id(self, id: str, index: Optional[str] = None) -> Optional[Document]:
         documents = self.get_documents_by_id([id], index)
@@ -132,8 +154,19 @@ class SQLDocumentStore(BaseDocumentStore):
             vector_id = meta_fields.get("vector_id")
             meta_orms = [MetaORM(name=key, value=value) for key, value in meta_fields.items()]
             doc_orm = DocumentORM(id=doc.id, text=doc.text, vector_id=vector_id, meta=meta_orms, index=index)
-            self.session.add(doc_orm)
-        self.session.commit()
+            if self.update_existing_documents:
+                # The old metadata must be deleted first; merge() alone would leave stale meta rows behind
+                self.session.query(MetaORM).filter_by(document_id=doc.id).delete()
+                self.session.merge(doc_orm)
+            else:
+                self.session.add(doc_orm)
+        try:
+            self.session.commit()
+        except Exception as ex:
+            logger.error(f"Transaction rollback: {ex.__cause__}")
+            # Rollback is important here; otherwise self.session is left in an inconsistent state and the next call will fail
+            self.session.rollback()
+            raise ex

     def write_labels(self, labels, index=None):
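The heart of the SQL-side change is the combination of merge() for upserts and an explicit rollback after a failed commit. The same pattern in standalone SQLAlchemy, with an illustrative Item model that is not part of Haystack:

from sqlalchemy import Column, String, create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker

Base = declarative_base()

class Item(Base):
    __tablename__ = "item"
    id = Column(String(100), primary_key=True)
    text = Column(String(1000))

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
session = sessionmaker(bind=engine)()

def write_item(item: Item, update_existing: bool = False):
    if update_existing:
        session.merge(item)  # INSERT or UPDATE, matched on the primary key
    else:
        session.add(item)    # plain INSERT; a duplicate id raises IntegrityError on commit
    try:
        session.commit()
    except Exception:
        # Without the rollback the session stays in a failed state and
        # every subsequent operation on it raises as well.
        session.rollback()
        raise

write_item(Item(id="1", text="v1"))
write_item(Item(id="1", text="v2"), update_existing=True)  # upsert succeeds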

View File

@@ -4,7 +4,6 @@ from elasticsearch import Elasticsearch
 from haystack import Document, Label
 from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
-from haystack.document_store.faiss import FAISSDocumentStore

 @pytest.mark.elasticsearch
@@ -85,6 +84,36 @@ def test_get_document_count(document_store):
     assert document_store.get_document_count(filters={"meta_field_for_count": ["b"]}) == 3

+@pytest.mark.elasticsearch
+@pytest.mark.parametrize("document_store", ["elasticsearch", "sql", "faiss"], indirect=True)
+@pytest.mark.parametrize("update_existing_documents", [True, False])
+def test_update_existing_documents(document_store, update_existing_documents):
+    original_docs = [
+        {"text": "text1_orig", "id": "1", "meta_field_for_count": "a"},
+    ]
+    updated_docs = [
+        {"text": "text1_new", "id": "1", "meta_field_for_count": "a"},
+    ]
+
+    document_store.update_existing_documents = update_existing_documents
+    document_store.write_documents(original_docs)
+    assert document_store.get_document_count() == 1
+
+    if update_existing_documents:
+        document_store.write_documents(updated_docs)
+    else:
+        with pytest.raises(Exception):
+            document_store.write_documents(updated_docs)
+
+    stored_docs = document_store.get_all_documents()
+    assert len(stored_docs) == 1
+    if update_existing_documents:
+        assert stored_docs[0].text == updated_docs[0]["text"]
+    else:
+        assert stored_docs[0].text == original_docs[0]["text"]

 @pytest.mark.elasticsearch
 def test_write_document_meta(document_store):
     documents = [
@@ -112,9 +141,8 @@ def test_write_document_index(document_store):
     document_store.write_documents([documents[0]], index="haystack_test_1")
     assert len(document_store.get_all_documents(index="haystack_test_1")) == 1

-    if not isinstance(document_store, FAISSDocumentStore):  # addition of more documents is not supported in FAISS
-        document_store.write_documents([documents[1]], index="haystack_test_2")
-        assert len(document_store.get_all_documents(index="haystack_test_2")) == 1
+    document_store.write_documents([documents[1]], index="haystack_test_2")
+    assert len(document_store.get_all_documents(index="haystack_test_2")) == 1

     assert len(document_store.get_all_documents(index="haystack_test_1")) == 1
     assert len(document_store.get_all_documents()) == 0
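Outside the test suite, the asserted behavior reduces to the following sketch against an in-memory SQLite store (assuming the module path haystack.document_store.sql used at the time of this commit):

import pytest
from haystack.document_store.sql import SQLDocumentStore

store = SQLDocumentStore(url="sqlite://", update_existing_documents=False)
store.write_documents([{"text": "text1_orig", "id": "1"}])

with pytest.raises(Exception):  # a duplicate ID raises while updates are disabled
    store.write_documents([{"text": "text1_new", "id": "1"}])

store.update_existing_documents = True
store.write_documents([{"text": "text1_new", "id": "1"}])  # now an upsert
assert store.get_all_documents()[0].text == "text1_new"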