Repository: https://github.com/deepset-ai/haystack.git
Add document update for SQL and FAISS Document Store (#584)

Commit 3f81c93f36 (parent 3e095ddd7d)
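This commit introduces an `update_existing_documents` flag so that writing a document whose ID already exists updates the stored row instead of raising an error. A minimal usage sketch, assuming an in-memory SQLite URL and that the sql module path mirrors the faiss import used in the tests below:

    from haystack.document_store.sql import SQLDocumentStore

    store = SQLDocumentStore(url="sqlite://", update_existing_documents=True)
    store.write_documents([{"text": "first version", "id": "1"}])
    store.write_documents([{"text": "second version", "id": "1"}])  # updates instead of raising
    assert store.get_all_documents()[0].text == "second version"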
haystack/document_store/faiss.py

@@ -37,6 +37,8 @@ class FAISSDocumentStore(SQLDocumentStore):
         faiss_index_factory_str: str = "Flat",
         faiss_index: Optional[faiss.swigfaiss.Index] = None,
         return_embedding: Optional[bool] = True,
+        update_existing_documents: bool = False,
+        index: str = "document",
         **kwargs,
     ):
         """
@@ -63,6 +65,11 @@ class FAISSDocumentStore(SQLDocumentStore):
         :param faiss_index: Pass an existing FAISS Index, i.e. an empty one that you configured manually
                             or one with docs that you used in Haystack before and want to load again.
         :param return_embedding: To return document embedding
+        :param update_existing_documents: Whether to update any existing documents with the same ID when adding
+                                          documents. When set to True, any document with an existing ID gets updated.
+                                          If set to False, an error is raised if the document ID of the document being
+                                          added already exists.
+        :param index: Name of index in document store to use.
         """
         self.vector_dim = vector_dim
 
@@ -73,7 +80,11 @@ class FAISSDocumentStore(SQLDocumentStore):
 
         self.index_buffer_size = index_buffer_size
         self.return_embedding = return_embedding
-        super().__init__(url=sql_url)
+        super().__init__(
+            url=sql_url,
+            update_existing_documents=update_existing_documents,
+            index=index
+        )
 
     def _create_new_index(self, vector_dim: int, index_factory: str = "Flat", metric_type=faiss.METRIC_INNER_PRODUCT, **kwargs):
         if index_factory == "HNSW" and metric_type == faiss.METRIC_INNER_PRODUCT:
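Since the constructor now forwards `update_existing_documents` and `index` to `SQLDocumentStore`, the FAISS store accepts the same options. A hedged construction sketch; `sql_url` and `faiss_index_factory_str` are the parameters visible in the hunks above:

    from haystack.document_store.faiss import FAISSDocumentStore

    store = FAISSDocumentStore(
        sql_url="sqlite://",              # backing SQL store for documents and metadata
        faiss_index_factory_str="Flat",   # exact-search index, the default
        update_existing_documents=True,   # forwarded to SQLDocumentStore
        index="document",
    )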
@@ -99,12 +110,18 @@ class FAISSDocumentStore(SQLDocumentStore):
         # vector index
         if not self.faiss_index:
             raise ValueError("Couldn't find a FAISS index. Try to init the FAISSDocumentStore() again ...")
 
         # doc + metadata index
         index = index or self.index
         document_objects = [Document.from_dict(d) if isinstance(d, dict) else d for d in documents]
 
         add_vectors = False if document_objects[0].embedding is None else True
 
+        if self.update_existing_documents and add_vectors:
+            logger.warning("You have enabled `update_existing_documents` feature and "
+                           "`FAISSDocumentStore` does not support update in existing `faiss_index`.\n"
+                           "Please call `update_embeddings` method to repopulate `faiss_index`")
+
         for i in range(0, len(document_objects), self.index_buffer_size):
             vector_id = self.faiss_index.ntotal
             if add_vectors:
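The warning is needed because the SQL rows can be updated in place while vectors already added to `faiss_index` cannot. A sketch of the intended recovery path; `store`, `docs_with_embeddings`, and `retriever` are placeholders, and `update_embeddings` taking a retriever reflects the Haystack API of this era:

    # Overwriting documents that carry embeddings leaves stale vectors in faiss_index.
    store.write_documents(docs_with_embeddings)  # SQL side updated; FAISS side now stale
    store.update_embeddings(retriever)           # resets faiss_index and re-adds vectors for all documents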
@@ -134,6 +151,9 @@ class FAISSDocumentStore(SQLDocumentStore):
         if not self.faiss_index:
             raise ValueError("Couldn't find a FAISS index. Try to init the FAISSDocumentStore() again ...")
 
+        # FAISS does not support in-place updates of an existing index, so clear all existing data first
+        self.faiss_index.reset()
+
         index = index or self.index
         documents = self.get_all_documents(index=index)
 
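`reset()` drops every stored vector and returns `ntotal` to zero, which is why `write_documents` can keep deriving the next `vector_id` from `ntotal`. A self-contained illustration with raw FAISS (dimension and vector count chosen arbitrarily):

    import faiss
    import numpy as np

    index = faiss.IndexFlatIP(768)                        # inner-product "Flat" index, matching the factory default
    index.add(np.random.rand(10, 768).astype("float32"))  # FAISS expects float32 input
    assert index.ntotal == 10

    index.reset()            # clears all vectors; the index object stays usable
    assert index.ntotal == 0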
haystack/document_store/sql.py

@@ -1,3 +1,4 @@
+import logging
 from typing import Any, Dict, Union, List, Optional
 from uuid import uuid4
 
@@ -10,6 +11,10 @@ from haystack.document_store.base import BaseDocumentStore
 from haystack import Document, Label
 from haystack.preprocessor.utils import eval_data_from_file
 
+
+logger = logging.getLogger(__name__)
+
+
 Base = declarative_base()  # type: Any
 
@@ -37,7 +42,7 @@ class MetaORM(ORMBase):
 
     name = Column(String(100), index=True)
     value = Column(String(1000), index=True)
-    document_id = Column(String(100), ForeignKey("document.id", ondelete="CASCADE"), nullable=False)
+    document_id = Column(String(100), ForeignKey("document.id", ondelete="CASCADE", onupdate="CASCADE"), nullable=False)
 
     documents = relationship(DocumentORM, backref="Meta")
 
@@ -45,7 +50,7 @@ class MetaORM(ORMBase):
 class LabelORM(ORMBase):
     __tablename__ = "label"
 
-    document_id = Column(String(100), ForeignKey("document.id", ondelete="CASCADE"), nullable=False)
+    document_id = Column(String(100), ForeignKey("document.id", ondelete="CASCADE", onupdate="CASCADE"), nullable=False)
     index = Column(String(100), nullable=False)
     no_answer = Column(Boolean, nullable=False)
     origin = Column(String(100), nullable=False)
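Adding `onupdate="CASCADE"` makes a change to a document's primary key propagate to its dependent meta and label rows rather than violating the foreign key. A toy schema sketch of the pattern, independent of the Haystack ORM classes (note that SQLite only enforces foreign keys when `PRAGMA foreign_keys=ON` is set on the connection):

    from sqlalchemy import Column, ForeignKey, String, create_engine
    from sqlalchemy.ext.declarative import declarative_base

    Base = declarative_base()

    class Parent(Base):
        __tablename__ = "parent"
        id = Column(String(100), primary_key=True)

    class Child(Base):
        __tablename__ = "child"
        id = Column(String(100), primary_key=True)
        # ON UPDATE CASCADE: renaming parent.id rewrites child.parent_id to match
        parent_id = Column(String(100), ForeignKey("parent.id", ondelete="CASCADE", onupdate="CASCADE"), nullable=False)

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)  # emits the FK with ON DELETE / ON UPDATE CASCADE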
@@ -58,13 +63,30 @@ class LabelORM(ORMBase):
 
 
 class SQLDocumentStore(BaseDocumentStore):
-    def __init__(self, url: str = "sqlite://", index="document"):
+    def __init__(
+        self,
+        url: str = "sqlite://",
+        index: str = "document",
+        label_index: str = "label",
+        update_existing_documents: bool = False,
+    ):
+        """
+        :param url: URL for SQL database as expected by SQLAlchemy. More info here: https://docs.sqlalchemy.org/en/13/core/engines.html#database-urls
+        :param index: The documents are scoped to an index attribute that can be used when writing, querying, or deleting documents.
+                      This parameter sets the default value for the document index.
+        :param label_index: The default value of the index attribute for the labels.
+        :param update_existing_documents: Whether to update any existing documents with the same ID when adding
+                                          documents. When set to True, any document with an existing ID gets updated.
+                                          If set to False, an error is raised if the document ID of the document being
+                                          added already exists. Using this parameter could cause performance degradation for document insertion.
+        """
         engine = create_engine(url)
         ORMBase.metadata.create_all(engine)
         Session = sessionmaker(bind=engine)
         self.session = Session()
         self.index = index
-        self.label_index = "label"
+        self.label_index = label_index
+        self.update_existing_documents = update_existing_documents
 
     def get_document_by_id(self, id: str, index: Optional[str] = None) -> Optional[Document]:
         documents = self.get_documents_by_id([id], index)
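The `index` attribute scopes reads and writes, as the docstring above describes. A short sketch of the scoping behavior, mirroring `test_write_document_index` further down:

    store = SQLDocumentStore(url="sqlite://", index="docs", label_index="labels")
    store.write_documents([{"text": "scoped doc", "id": "1"}], index="other")
    assert len(store.get_all_documents(index="other")) == 1
    assert len(store.get_all_documents()) == 0  # the default index "docs" holds nothing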
@@ -132,8 +154,19 @@ class SQLDocumentStore(BaseDocumentStore):
             vector_id = meta_fields.get("vector_id")
             meta_orms = [MetaORM(name=key, value=value) for key, value in meta_fields.items()]
             doc_orm = DocumentORM(id=doc.id, text=doc.text, vector_id=vector_id, meta=meta_orms, index=index)
-            self.session.add(doc_orm)
-        self.session.commit()
+            if self.update_existing_documents:
+                # The old metadata must be cleaned up first
+                self.session.query(MetaORM).filter_by(document_id=doc.id).delete()
+                self.session.merge(doc_orm)
+            else:
+                self.session.add(doc_orm)
+        try:
+            self.session.commit()
+        except Exception as ex:
+            logger.error(f"Transaction rollback: {ex.__cause__}")
+            # Rollback is important here, otherwise self.session will be in an inconsistent state and the next call will fail
+            self.session.rollback()
+            raise ex
 
     def write_labels(self, labels, index=None):
 
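`session.merge()` acts as an ORM-level upsert: it looks the row up by primary key and issues an UPDATE if it exists, an INSERT otherwise. It does not reconcile child meta rows, hence the explicit delete beforehand. A self-contained toy showing the add/merge difference:

    from sqlalchemy import Column, String, create_engine
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import sessionmaker

    Base = declarative_base()

    class Doc(Base):  # toy stand-in for DocumentORM
        __tablename__ = "doc"
        id = Column(String(100), primary_key=True)
        text = Column(String(1000))

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    session = sessionmaker(bind=engine)()

    session.add(Doc(id="1", text="original"))
    session.commit()

    # A second session.add(Doc(id="1", ...)) would fail on commit with an
    # IntegrityError; merge() updates the existing row instead.
    session.merge(Doc(id="1", text="updated"))
    session.commit()
    assert session.query(Doc).filter_by(id="1").one().text == "updated"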
test/test_document_store.py

@@ -4,7 +4,6 @@ from elasticsearch import Elasticsearch
 
 from haystack import Document, Label
 from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
-from haystack.document_store.faiss import FAISSDocumentStore
 
 
 @pytest.mark.elasticsearch
@@ -85,6 +84,36 @@ def test_get_document_count(document_store):
     assert document_store.get_document_count(filters={"meta_field_for_count": ["b"]}) == 3
 
 
+@pytest.mark.elasticsearch
+@pytest.mark.parametrize("document_store", ["elasticsearch", "sql", "faiss"], indirect=True)
+@pytest.mark.parametrize("update_existing_documents", [True, False])
+def test_update_existing_documents(document_store, update_existing_documents):
+    original_docs = [
+        {"text": "text1_orig", "id": "1", "meta_field_for_count": "a"},
+    ]
+
+    updated_docs = [
+        {"text": "text1_new", "id": "1", "meta_field_for_count": "a"},
+    ]
+
+    document_store.update_existing_documents = update_existing_documents
+    document_store.write_documents(original_docs)
+    assert document_store.get_document_count() == 1
+
+    if update_existing_documents:
+        document_store.write_documents(updated_docs)
+    else:
+        with pytest.raises(Exception):
+            document_store.write_documents(updated_docs)
+
+    stored_docs = document_store.get_all_documents()
+    assert len(stored_docs) == 1
+    if update_existing_documents:
+        assert stored_docs[0].text == updated_docs[0]["text"]
+    else:
+        assert stored_docs[0].text == original_docs[0]["text"]
+
+
 @pytest.mark.elasticsearch
 def test_write_document_meta(document_store):
     documents = [
@@ -112,9 +141,8 @@ def test_write_document_index(document_store):
     document_store.write_documents([documents[0]], index="haystack_test_1")
     assert len(document_store.get_all_documents(index="haystack_test_1")) == 1
 
-    if not isinstance(document_store, FAISSDocumentStore):  # addition of more documents is not supported in FAISS
-        document_store.write_documents([documents[1]], index="haystack_test_2")
-        assert len(document_store.get_all_documents(index="haystack_test_2")) == 1
+    document_store.write_documents([documents[1]], index="haystack_test_2")
+    assert len(document_store.get_all_documents(index="haystack_test_2")) == 1
 
     assert len(document_store.get_all_documents(index="haystack_test_1")) == 1
     assert len(document_store.get_all_documents()) == 0
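Dropping the `FAISSDocumentStore` special case appears to rely on the bookkeeping above: `vector_id` is re-read from `faiss_index.ntotal` on each buffered write, so consecutive `write_documents` calls no longer need to be guarded for FAISS.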