feat: make DocumentWriter return the actual number of documents written (#6366)

* make DocumentWriter return the actual number of documents written

* add/improve tests
This commit is contained in:
Stefano Fiorucci 2023-11-21 15:54:25 +01:00 committed by GitHub
parent ec3558021e
commit 456902235a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 22 additions and 8 deletions

View File

@ -64,5 +64,5 @@ class DocumentWriter:
if policy is None:
policy = self.policy
self.document_store.write_documents(documents=documents, policy=policy)
return {"documents_written": len(documents)}
documents_written = self.document_store.write_documents(documents=documents, policy=policy)
return {"documents_written": documents_written}

View File

@ -1,11 +1,10 @@
from unittest.mock import MagicMock
import pytest
from haystack.preview import Document, DeserializationError
from haystack.preview.testing.factory import document_store_class
from haystack.preview.components.writers.document_writer import DocumentWriter
from haystack.preview.document_stores import DuplicatePolicy
from haystack.preview.document_stores.in_memory import InMemoryDocumentStore
class TestDocumentWriter:
@ -81,12 +80,27 @@ class TestDocumentWriter:
@pytest.mark.unit
def test_run(self):
mocked_document_store = MagicMock()
writer = DocumentWriter(mocked_document_store)
document_store = InMemoryDocumentStore()
writer = DocumentWriter(document_store)
documents = [
Document(content="This is the text of a document."),
Document(content="This is the text of another document."),
]
writer.run(documents=documents)
mocked_document_store.write_documents.assert_called_once_with(documents=documents, policy=DuplicatePolicy.FAIL)
result = writer.run(documents=documents)
assert result["documents_written"] == 2
@pytest.mark.unit
def test_run_skip_policy(self):
document_store = InMemoryDocumentStore()
writer = DocumentWriter(document_store, policy=DuplicatePolicy.SKIP)
documents = [
Document(content="This is the text of a document."),
Document(content="This is the text of another document."),
]
result = writer.run(documents=documents)
assert result["documents_written"] == 2
result = writer.run(documents=documents)
assert result["documents_written"] == 0