Fix delete_all_documents for the SQLDocumentStore (#761)
This commit is contained in:
parent aee90c5df9
commit f0aa879a1c
@@ -388,7 +388,7 @@ In-memory document store
 #### \_\_init\_\_
 
 ```python
- | __init__(embedding_field: Optional[str] = "embedding", return_embedding: bool = False, similarity="dot_product")
+ | __init__(index: str = "document", label_index: str = "label", embedding_field: Optional[str] = "embedding", embedding_dim: int = 768, return_embedding: bool = False, similarity: str = "dot_product")
 ```
 
 **Arguments**:

@@ -19,7 +19,15 @@ class InMemoryDocumentStore(BaseDocumentStore):
     In-memory document store
     """
 
-    def __init__(self, embedding_field: Optional[str] = "embedding", return_embedding: bool = False, similarity="dot_product"):
+    def __init__(
+        self,
+        index: str = "document",
+        label_index: str = "label",
+        embedding_field: Optional[str] = "embedding",
+        embedding_dim: int = 768,
+        return_embedding: bool = False,
+        similarity: str = "dot_product",
+    ):
         """
         :param embedding_field: Name of field containing an embedding vector (Only needed when using a dense retriever (e.g. DensePassageRetriever, EmbeddingRetriever) on top)
         :param return_embedding: To return document embedding

@@ -27,12 +35,12 @@ class InMemoryDocumentStore(BaseDocumentStore):
                            more performant with DPR embeddings. 'cosine' is recommended if you are using a Sentence BERT model.
         """
         self.indexes: Dict[str, Dict] = defaultdict(dict)
-        self.index: str = "document"
-        self.label_index: str = "label"
-        self.embedding_field: str = embedding_field if embedding_field is not None else "embedding"
-        self.embedding_dim: int = 768
-        self.return_embedding: bool = return_embedding
-        self.similarity: str = similarity
+        self.index: str = index
+        self.label_index: str = label_index
+        self.embedding_field = embedding_field
+        self.embedding_dim = embedding_dim
+        self.return_embedding = return_embedding
+        self.similarity = similarity
 
     def write_documents(self, documents: Union[List[dict], List[Document]], index: Optional[str] = None):
         """

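For context on the widened `InMemoryDocumentStore.__init__`, here is a minimal usage sketch. It assumes the import path of this version of the repo; the index names and example document are invented for illustration.

```python
from haystack.document_store.memory import InMemoryDocumentStore

# Hypothetical example: the custom index names and the document below are illustrative only.
document_store = InMemoryDocumentStore(
    index="my_docs",              # default index that write_documents() targets
    label_index="my_labels",      # default index for labels
    embedding_field="embedding",  # field a dense retriever would fill
    embedding_dim=768,
    return_embedding=False,
    similarity="dot_product",
)

document_store.write_documents([{"text": "Haystack stores documents in memory."}])
print(document_store.get_all_documents(index="my_docs"))
```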
@@ -403,6 +403,7 @@ class SQLDocumentStore(BaseDocumentStore):
         index = index or self.index
         documents = self.session.query(DocumentORM).filter_by(index=index)
         documents.delete(synchronize_session=False)
+        self.session.commit()
 
     def _get_or_create(self, session, model, **kwargs):
         instance = session.query(model).filter_by(**kwargs).first()

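A minimal sketch of the behaviour this commit fixes, assuming the `haystack.document_store.sql` import path and the in-memory SQLite URL used elsewhere in the diff; the index name and documents are illustrative, not taken from the test suite.

```python
from haystack.document_store.sql import SQLDocumentStore

# Illustrative only: throwaway index name and documents.
document_store = SQLDocumentStore(url="sqlite://", index="scratch")
document_store.write_documents([{"text": "doc one"}, {"text": "doc two"}])
assert len(document_store.get_all_documents(index="scratch")) == 2

# After the fix, the deletion is persisted, so deleted documents no longer
# reappear when the session is queried (or rolled back) later on.
document_store.delete_all_documents(index="scratch")
assert document_store.get_all_documents(index="scratch") == []
```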
@@ -22,6 +22,24 @@ from haystack.reader.transformers import TransformersReader
 from haystack.summarizer.transformers import TransformersSummarizer
 
 
+def _sql_session_rollback(self, attr):
+    """
+    Inject SQLDocumentStore at runtime to do a session rollback each time it is called. This allows to catch
+    errors where an intended operation is still in a transaction, but not committed to the database.
+    """
+    method = object.__getattribute__(self, attr)
+    if callable(method):
+        try:
+            self.session.rollback()
+        except AttributeError:
+            pass
+
+    return method
+
+
+SQLDocumentStore.__getattribute__ = _sql_session_rollback
+
+
 def pytest_collection_modifyitems(items):
     for item in items:
         if "generator" in item.nodeid:

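The conftest patch above relies on `__getattribute__` interception: every attribute lookup on `SQLDocumentStore` first rolls back any open session, so work that was never committed surfaces as a test failure. A stripped-down sketch of the same pattern on invented dummy classes (names are made up for illustration):

```python
class FakeSession:
    def __init__(self):
        self.rollbacks = 0

    def rollback(self):
        self.rollbacks += 1


class FakeStore:
    def __init__(self):
        self.session = FakeSession()

    def query_documents(self):
        return "some documents"


def _rollback_before_calls(self, attr):
    # Resolve the attribute without re-entering this hook.
    value = object.__getattribute__(self, attr)
    if callable(value):
        try:
            # 'session' is looked up through this hook too, but it is not
            # callable, so the lookup does not recurse further.
            self.session.rollback()
        except AttributeError:
            # e.g. during __init__, before self.session exists
            pass
    return value


FakeStore.__getattribute__ = _rollback_before_calls

store = FakeStore()
store.query_documents()
store.query_documents()
print(store.session.rollbacks)  # 2 -- a rollback ran before each method call
```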
@@ -246,9 +264,11 @@ def document_store(request, test_docs_xs):
 
 def get_document_store(document_store_type, embedding_field="embedding"):
     if document_store_type == "sql":
-        document_store = SQLDocumentStore(url="sqlite://")
+        document_store = SQLDocumentStore(url="sqlite://", index="haystack_test")
     elif document_store_type == "memory":
-        document_store = InMemoryDocumentStore(return_embedding=True, embedding_field=embedding_field)
+        document_store = InMemoryDocumentStore(
+            return_embedding=True, embedding_field=embedding_field, index="haystack_test"
+        )
     elif document_store_type == "elasticsearch":
        # make sure we start from a fresh index
        client = Elasticsearch()

@@ -261,6 +281,7 @@ def get_document_store(document_store_type, embedding_field="embedding"):
             sql_url="sqlite://",
             return_embedding=True,
             embedding_field=embedding_field,
+            index="haystack_test",
         )
         return document_store
     else:

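A hedged sketch of how the updated helper might be exercised; the example documents are made up, and in the real suite this path is taken via the `document_store` fixture rather than a direct call:

```python
# Illustrative only: exercising the conftest helper for two of the backends above.
for backend in ("sql", "memory"):
    store = get_document_store(backend)
    store.write_documents([{"text": "hello"}, {"text": "world"}])
    # Every branch now pins the store to index="haystack_test",
    # so index-addressed reads behave the same across backends.
    assert len(store.get_all_documents(index="haystack_test")) == 2
```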
@@ -1,6 +0,0 @@
-from haystack import Document
-
-
-def test_document_data_access():
-    doc = Document(text="test")
-    assert doc.text == "test"

@@ -218,19 +218,22 @@ def test_update_embeddings(document_store, retriever):
 
 
 @pytest.mark.elasticsearch
 @pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
-def test_delete_documents(document_store_with_docs):
-    assert len(document_store_with_docs.get_all_documents()) == 3
+def test_delete_all_documents(document_store_with_docs):
+    assert len(document_store_with_docs.get_all_documents(index="haystack_test")) == 3
 
+    document_store_with_docs.delete_all_documents(index="haystack_test")
+    documents = document_store_with_docs.get_all_documents(index="haystack_test")
+    assert len(documents) == 0
+
+
+@pytest.mark.elasticsearch
+@pytest.mark.parametrize("document_store_with_docs", ["elasticsearch"], indirect=True)
+def test_delete_documents_with_filters(document_store_with_docs):
     document_store_with_docs.delete_all_documents(index="haystack_test", filters={"meta_field": ["test1", "test2"]})
     documents = document_store_with_docs.get_all_documents()
     assert len(documents) == 1
     assert documents[0].meta["meta_field"] == "test3"
 
-    document_store_with_docs.delete_all_documents(index="haystack_test")
-    documents = document_store_with_docs.get_all_documents()
-    assert len(documents) == 0
-
-
 @pytest.mark.elasticsearch
 def test_labels(document_store):

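The assertions in `test_delete_documents_with_filters` imply that the `document_store_with_docs` fixture holds three documents whose `meta_field` values are `test1`, `test2`, and `test3`. A sketch of data with that assumed shape (not copied from conftest.py):

```python
# Assumed shape of the fixture data, inferred from the assertions above.
docs = [
    {"text": "Doc one", "meta": {"meta_field": "test1"}},
    {"text": "Doc two", "meta": {"meta_field": "test2"}},
    {"text": "Doc three", "meta": {"meta_field": "test3"}},
]
# delete_all_documents(index="haystack_test", filters={"meta_field": ["test1", "test2"]})
# removes the first two, leaving only the "test3" document behind.
```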
@@ -394,8 +397,8 @@ def test_multilabel_no_answer(document_store):
 
 
 @pytest.mark.elasticsearch
-@pytest.mark.parametrize("document_store", ["elasticsearch", "sql"], indirect=True)
-def test_elasticsearch_update_meta(document_store):
+@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "sql"], indirect=True)
+def test_update_meta(document_store):
     documents = [
         Document(
             text="Doc1",