Add document update for SQL and FAISS Document Store (#584)
Commit: 3f81c93f36
Parent: 3e095ddd7d
haystack/document_store/faiss.py:

@@ -37,6 +37,8 @@ class FAISSDocumentStore(SQLDocumentStore):
         faiss_index_factory_str: str = "Flat",
         faiss_index: Optional[faiss.swigfaiss.Index] = None,
         return_embedding: Optional[bool] = True,
+        update_existing_documents: bool = False,
+        index: str = "document",
         **kwargs,
     ):
         """
@@ -63,6 +65,11 @@ class FAISSDocumentStore(SQLDocumentStore):
         :param faiss_index: Pass an existing FAISS Index, i.e. an empty one that you configured manually
                             or one with docs that you used in Haystack before and want to load again.
         :param return_embedding: To return document embedding
+        :param update_existing_documents: Whether to update any existing documents with the same ID when adding
+                                          documents. When set as True, any document with an existing ID gets updated.
+                                          If set to False, an error is raised if the document ID of the document being
+                                          added already exists.
+        :param index: Name of index in document store to use.
         """
         self.vector_dim = vector_dim

@@ -73,7 +80,11 @@ class FAISSDocumentStore(SQLDocumentStore):

         self.index_buffer_size = index_buffer_size
         self.return_embedding = return_embedding
-        super().__init__(url=sql_url)
+        super().__init__(
+            url=sql_url,
+            update_existing_documents=update_existing_documents,
+            index=index
+        )

     def _create_new_index(self, vector_dim: int, index_factory: str = "Flat", metric_type=faiss.METRIC_INNER_PRODUCT, **kwargs):
         if index_factory == "HNSW" and metric_type == faiss.METRIC_INNER_PRODUCT:
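
A minimal usage sketch of the two new constructor arguments; the sqlite URL here is illustrative, and faiss must be installed:

    from haystack.document_store.faiss import FAISSDocumentStore

    # Illustrative values: any SQLAlchemy URL works for sql_url.
    document_store = FAISSDocumentStore(
        sql_url="sqlite:///haystack_faiss.db",
        index="document",                # default index name for documents
        update_existing_documents=True,  # forwarded to SQLDocumentStore, see the hunk above
    )
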
@@ -99,12 +110,18 @@ class FAISSDocumentStore(SQLDocumentStore):
         # vector index
         if not self.faiss_index:
             raise ValueError("Couldn't find a FAISS index. Try to init the FAISSDocumentStore() again ...")

         # doc + metadata index
         index = index or self.index
         document_objects = [Document.from_dict(d) if isinstance(d, dict) else d for d in documents]

         add_vectors = False if document_objects[0].embedding is None else True
+
+        if self.update_existing_documents and add_vectors:
+            logger.warning("You have enabled `update_existing_documents` feature and "
+                           "`FAISSDocumentStore` does not support update in existing `faiss_index`.\n"
+                           "Please call `update_embeddings` method to repopulate `faiss_index`")
+
         for i in range(0, len(document_objects), self.index_buffer_size):
             vector_id = self.faiss_index.ntotal
             if add_vectors:
@@ -134,6 +151,9 @@ class FAISSDocumentStore(SQLDocumentStore):
         if not self.faiss_index:
             raise ValueError("Couldn't find a FAISS index. Try to init the FAISSDocumentStore() again ...")

+        # Faiss does not support update in existing index data so clear all existing data in it
+        self.faiss_index.reset()
+
         index = index or self.index
         documents = self.get_all_documents(index=index)

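
Taken together with the warning in write_documents above, the intended update flow for a FAISS store is: overwrite the SQL rows, then rebuild the vector index, since FAISS cannot update vectors in place. A sketch, where document_store, updated_docs, and retriever are assumed to already exist (any Haystack retriever with passage embedding support):

    # Overwrites the SQL rows for matching ids; faiss_index still holds stale vectors.
    document_store.write_documents(updated_docs)

    # Resets faiss_index (see the hunk above) and re-embeds every stored document.
    document_store.update_embeddings(retriever)
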
haystack/document_store/sql.py:

@@ -1,3 +1,4 @@
+import logging
 from typing import Any, Dict, Union, List, Optional
 from uuid import uuid4

@@ -10,6 +11,10 @@ from haystack.document_store.base import BaseDocumentStore
 from haystack import Document, Label
 from haystack.preprocessor.utils import eval_data_from_file

+
+logger = logging.getLogger(__name__)
+
+
 Base = declarative_base()  # type: Any


@@ -37,7 +42,7 @@ class MetaORM(ORMBase):

     name = Column(String(100), index=True)
     value = Column(String(1000), index=True)
-    document_id = Column(String(100), ForeignKey("document.id", ondelete="CASCADE"), nullable=False)
+    document_id = Column(String(100), ForeignKey("document.id", ondelete="CASCADE", onupdate="CASCADE"), nullable=False)

     documents = relationship(DocumentORM, backref="Meta")
@@ -45,7 +50,7 @@ class MetaORM(ORMBase):
 class LabelORM(ORMBase):
     __tablename__ = "label"

-    document_id = Column(String(100), ForeignKey("document.id", ondelete="CASCADE"), nullable=False)
+    document_id = Column(String(100), ForeignKey("document.id", ondelete="CASCADE", onupdate="CASCADE"), nullable=False)
     index = Column(String(100), nullable=False)
     no_answer = Column(Boolean, nullable=False)
     origin = Column(String(100), nullable=False)
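
Why onupdate="CASCADE" is added in both ORM classes: if a document row's primary key is rewritten during an update, the dependent meta and label rows should follow it rather than trip the foreign key constraint. A toy declarative model with the same constraint (class and table names here are illustrative, not Haystack's):

    from sqlalchemy import Column, ForeignKey, String
    from sqlalchemy.ext.declarative import declarative_base

    Base = declarative_base()

    class DocORM(Base):
        __tablename__ = "document"
        id = Column(String(100), primary_key=True)

    class ChildORM(Base):
        __tablename__ = "child"
        id = Column(String(100), primary_key=True)
        # ON DELETE CASCADE: deleting a document row deletes its child rows.
        # ON UPDATE CASCADE: rewriting document.id updates document_id here as well.
        document_id = Column(
            String(100),
            ForeignKey("document.id", ondelete="CASCADE", onupdate="CASCADE"),
            nullable=False,
        )
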
@@ -58,13 +63,30 @@ class LabelORM(ORMBase):


 class SQLDocumentStore(BaseDocumentStore):
-    def __init__(self, url: str = "sqlite://", index="document"):
+    def __init__(
+        self,
+        url: str = "sqlite://",
+        index: str = "document",
+        label_index: str = "label",
+        update_existing_documents: bool = False,
+    ):
+        """
+        :param url: URL for SQL database as expected by SQLAlchemy. More info here: https://docs.sqlalchemy.org/en/13/core/engines.html#database-urls
+        :param index: The documents are scoped to an index attribute that can be used when writing, querying, or deleting documents.
+                      This parameter sets the default value for document index.
+        :param label_index: The default value of index attribute for the labels.
+        :param update_existing_documents: Whether to update any existing documents with the same ID when adding
+                                          documents. When set as True, any document with an existing ID gets updated.
+                                          If set to False, an error is raised if the document ID of the document being
+                                          added already exists. Using this parameter could cause performance degradation for document insertion.
+        """
         engine = create_engine(url)
         ORMBase.metadata.create_all(engine)
         Session = sessionmaker(bind=engine)
         self.session = Session()
         self.index = index
-        self.label_index = "label"
+        self.label_index = label_index
+        self.update_existing_documents = update_existing_documents

     def get_document_by_id(self, id: str, index: Optional[str] = None) -> Optional[Document]:
         documents = self.get_documents_by_id([id], index)
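
A hedged usage sketch of the extended constructor (the sqlite file URL is illustrative):

    from haystack.document_store.sql import SQLDocumentStore

    document_store = SQLDocumentStore(
        url="sqlite:///haystack.db",
        index="document",
        label_index="label",
        update_existing_documents=True,  # writing a duplicate id now updates instead of raising
    )
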
@@ -132,8 +154,19 @@ class SQLDocumentStore(BaseDocumentStore):
             vector_id = meta_fields.get("vector_id")
             meta_orms = [MetaORM(name=key, value=value) for key, value in meta_fields.items()]
             doc_orm = DocumentORM(id=doc.id, text=doc.text, vector_id=vector_id, meta=meta_orms, index=index)
-            self.session.add(doc_orm)
-        self.session.commit()
+            if self.update_existing_documents:
+                # First old meta data cleaning is required
+                self.session.query(MetaORM).filter_by(document_id=doc.id).delete()
+                self.session.merge(doc_orm)
+            else:
+                self.session.add(doc_orm)
+        try:
+            self.session.commit()
+        except Exception as ex:
+            logger.error(f"Transaction rollback: {ex.__cause__}")
+            # Rollback is important here otherwise self.session will be in inconsistent state and next call will fail
+            self.session.rollback()
+            raise ex

     def write_labels(self, labels, index=None):

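
The update path relies on Session.merge, which emits an UPDATE when a row with the same primary key already exists and an INSERT otherwise; the MetaORM rows are deleted first because a merge would otherwise accumulate duplicate meta children. A self-contained sketch of the same pattern against an in-memory SQLite database (toy model, not Haystack's):

    from sqlalchemy import Column, String, create_engine
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import sessionmaker

    Base = declarative_base()

    class Doc(Base):
        __tablename__ = "doc"
        id = Column(String(100), primary_key=True)
        text = Column(String(1000))

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    session = sessionmaker(bind=engine)()

    session.add(Doc(id="1", text="text1_orig"))
    session.commit()

    # merge() upserts: the existing row with id="1" is updated in place, not duplicated.
    session.merge(Doc(id="1", text="text1_new"))
    try:
        session.commit()
    except Exception:
        # Roll back so the session stays usable after a failed transaction,
        # mirroring the error handling added above.
        session.rollback()
        raise

    assert session.query(Doc).one().text == "text1_new"
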
test/test_document_store.py:

@@ -4,7 +4,6 @@ from elasticsearch import Elasticsearch

 from haystack import Document, Label
 from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
-from haystack.document_store.faiss import FAISSDocumentStore


 @pytest.mark.elasticsearch
@@ -85,6 +84,36 @@ def test_get_document_count(document_store):
     assert document_store.get_document_count(filters={"meta_field_for_count": ["b"]}) == 3


+@pytest.mark.elasticsearch
+@pytest.mark.parametrize("document_store", ["elasticsearch", "sql", "faiss"], indirect=True)
+@pytest.mark.parametrize("update_existing_documents", [True, False])
+def test_update_existing_documents(document_store, update_existing_documents):
+    original_docs = [
+        {"text": "text1_orig", "id": "1", "meta_field_for_count": "a"},
+    ]
+
+    updated_docs = [
+        {"text": "text1_new", "id": "1", "meta_field_for_count": "a"},
+    ]
+
+    document_store.update_existing_documents = update_existing_documents
+    document_store.write_documents(original_docs)
+    assert document_store.get_document_count() == 1
+
+    if update_existing_documents:
+        document_store.write_documents(updated_docs)
+    else:
+        with pytest.raises(Exception):
+            document_store.write_documents(updated_docs)
+
+    stored_docs = document_store.get_all_documents()
+    assert len(stored_docs) == 1
+    if update_existing_documents:
+        assert stored_docs[0].text == updated_docs[0]["text"]
+    else:
+        assert stored_docs[0].text == original_docs[0]["text"]
+
+
 @pytest.mark.elasticsearch
 def test_write_document_meta(document_store):
     documents = [
@@ -112,9 +141,8 @@ def test_write_document_index(document_store):
     document_store.write_documents([documents[0]], index="haystack_test_1")
     assert len(document_store.get_all_documents(index="haystack_test_1")) == 1

-    if not isinstance(document_store, FAISSDocumentStore):  # addition of more documents is not supported in FAISS
-        document_store.write_documents([documents[1]], index="haystack_test_2")
-        assert len(document_store.get_all_documents(index="haystack_test_2")) == 1
+    document_store.write_documents([documents[1]], index="haystack_test_2")
+    assert len(document_store.get_all_documents(index="haystack_test_2")) == 1

     assert len(document_store.get_all_documents(index="haystack_test_1")) == 1
     assert len(document_store.get_all_documents()) == 0