Fix delete_all_documents for the SQLDocumentStore (#761)

This commit is contained in:
Tanay Soni 2021-01-22 14:39:24 +01:00 committed by GitHub
parent aee90c5df9
commit f0aa879a1c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 52 additions and 25 deletions

View File

@@ -388,7 +388,7 @@ In-memory document store
#### \_\_init\_\_
```python
| __init__(embedding_field: Optional[str] = "embedding", return_embedding: bool = False, similarity="dot_product")
| __init__(index: str = "document", label_index: str = "label", embedding_field: Optional[str] = "embedding", embedding_dim: int = 768, return_embedding: bool = False, similarity: str = "dot_product")
```
**Arguments**:

View File

@@ -19,7 +19,15 @@ class InMemoryDocumentStore(BaseDocumentStore):
In-memory document store
"""
def __init__(self, embedding_field: Optional[str] = "embedding", return_embedding: bool = False, similarity="dot_product"):
def __init__(
self,
index: str = "document",
label_index: str = "label",
embedding_field: Optional[str] = "embedding",
embedding_dim: int = 768,
return_embedding: bool = False,
similarity: str = "dot_product",
):
"""
:param embedding_field: Name of field containing an embedding vector (Only needed when using a dense retriever (e.g. DensePassageRetriever, EmbeddingRetriever) on top)
:param return_embedding: To return document embedding
@@ -27,12 +35,12 @@ class InMemoryDocumentStore(BaseDocumentStore):
more performant with DPR embeddings. 'cosine' is recommended if you are using a Sentence BERT model.
"""
self.indexes: Dict[str, Dict] = defaultdict(dict)
self.index: str = "document"
self.label_index: str = "label"
self.embedding_field: str = embedding_field if embedding_field is not None else "embedding"
self.embedding_dim: int = 768
self.return_embedding: bool = return_embedding
self.similarity: str = similarity
self.index: str = index
self.label_index: str = label_index
self.embedding_field = embedding_field
self.embedding_dim = embedding_dim
self.return_embedding = return_embedding
self.similarity = similarity
def write_documents(self, documents: Union[List[dict], List[Document]], index: Optional[str] = None):
"""

View File

@@ -403,6 +403,7 @@ class SQLDocumentStore(BaseDocumentStore):
index = index or self.index
documents = self.session.query(DocumentORM).filter_by(index=index)
documents.delete(synchronize_session=False)
self.session.commit()
def _get_or_create(self, session, model, **kwargs):
instance = session.query(model).filter_by(**kwargs).first()

View File

@@ -22,6 +22,24 @@ from haystack.reader.transformers import TransformersReader
from haystack.summarizer.transformers import TransformersSummarizer
def _sql_session_rollback(self, attr):
"""
Inject SQLDocumentStore at runtime to do a session rollback each time it is called. This allows to catch
errors where an intended operation is still in a transaction, but not committed to the database.
"""
method = object.__getattribute__(self, attr)
if callable(method):
try:
self.session.rollback()
except AttributeError:
pass
return method
SQLDocumentStore.__getattribute__ = _sql_session_rollback
def pytest_collection_modifyitems(items):
for item in items:
if "generator" in item.nodeid:
@@ -246,9 +264,11 @@ def document_store(request, test_docs_xs):
def get_document_store(document_store_type, embedding_field="embedding"):
if document_store_type == "sql":
document_store = SQLDocumentStore(url="sqlite://")
document_store = SQLDocumentStore(url="sqlite://", index="haystack_test")
elif document_store_type == "memory":
document_store = InMemoryDocumentStore(return_embedding=True, embedding_field=embedding_field)
document_store = InMemoryDocumentStore(
return_embedding=True, embedding_field=embedding_field, index="haystack_test"
)
elif document_store_type == "elasticsearch":
# make sure we start from a fresh index
client = Elasticsearch()
@@ -261,6 +281,7 @@ def get_document_store(document_store_type, embedding_field="embedding"):
sql_url="sqlite://",
return_embedding=True,
embedding_field=embedding_field,
index="haystack_test",
)
return document_store
else:

View File

@@ -1,6 +0,0 @@
from haystack import Document
def test_document_data_access():
    """Smoke test: text passed to the Document constructor is readable back via ``.text``."""
    assert Document(text="test").text == "test"

View File

@@ -218,19 +218,22 @@ def test_update_embeddings(document_store, retriever):
@pytest.mark.elasticsearch
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
def test_delete_documents(document_store_with_docs):
assert len(document_store_with_docs.get_all_documents()) == 3
def test_delete_all_documents(document_store_with_docs):
    """Deleting every document from the test index must leave it empty."""
    store = document_store_with_docs
    # Fixture pre-loads exactly three documents into the test index.
    assert len(store.get_all_documents(index="haystack_test")) == 3
    store.delete_all_documents(index="haystack_test")
    assert len(store.get_all_documents(index="haystack_test")) == 0
@pytest.mark.elasticsearch
@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
def test_delete_documents_with_filters(document_store_with_docs):
    """
    A filtered delete removes only the matching documents; a subsequent
    unfiltered delete then empties the index completely.
    """
    store = document_store_with_docs
    store.delete_all_documents(index="haystack_test", filters={"meta_field": ["test1", "test2"]})
    remaining = store.get_all_documents()
    assert len(remaining) == 1
    # Only the document whose meta_field was not in the filter survives.
    assert remaining[0].meta["meta_field"] == "test3"
    store.delete_all_documents(index="haystack_test")
    assert len(store.get_all_documents()) == 0
@pytest.mark.elasticsearch
def test_labels(document_store):
@@ -394,8 +397,8 @@ def test_multilabel_no_answer(document_store):
@pytest.mark.elasticsearch
@pytest.mark.parametrize("document_store", ["elasticsearch", "sql"], indirect=True)
def test_elasticsearch_update_meta(document_store):
@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "sql"], indirect=True)
def test_update_meta(document_store):
documents = [
Document(
text="Doc1",