diff --git a/haystack/preview/components/writers/document_writer.py b/haystack/preview/components/writers/document_writer.py index 5b3d2f4ac..2ce45afde 100644 --- a/haystack/preview/components/writers/document_writer.py +++ b/haystack/preview/components/writers/document_writer.py @@ -64,5 +64,5 @@ class DocumentWriter: if policy is None: policy = self.policy - self.document_store.write_documents(documents=documents, policy=policy) - return {"documents_written": len(documents)} + documents_written = self.document_store.write_documents(documents=documents, policy=policy) + return {"documents_written": documents_written} diff --git a/test/preview/components/writers/test_document_writer.py b/test/preview/components/writers/test_document_writer.py index b535616c9..ed5b9a411 100644 --- a/test/preview/components/writers/test_document_writer.py +++ b/test/preview/components/writers/test_document_writer.py @@ -1,11 +1,10 @@ -from unittest.mock import MagicMock - import pytest from haystack.preview import Document, DeserializationError from haystack.preview.testing.factory import document_store_class from haystack.preview.components.writers.document_writer import DocumentWriter from haystack.preview.document_stores import DuplicatePolicy +from haystack.preview.document_stores.in_memory import InMemoryDocumentStore class TestDocumentWriter: @@ -81,12 +80,27 @@ class TestDocumentWriter: @pytest.mark.unit def test_run(self): - mocked_document_store = MagicMock() - writer = DocumentWriter(mocked_document_store) + document_store = InMemoryDocumentStore() + writer = DocumentWriter(document_store) documents = [ Document(content="This is the text of a document."), Document(content="This is the text of another document."), ] - writer.run(documents=documents) - mocked_document_store.write_documents.assert_called_once_with(documents=documents, policy=DuplicatePolicy.FAIL) + result = writer.run(documents=documents) + assert result["documents_written"] == 2 + + @pytest.mark.unit + def test_run_skip_policy(self): + document_store = InMemoryDocumentStore() + writer = DocumentWriter(document_store, policy=DuplicatePolicy.SKIP) + documents = [ + Document(content="This is the text of a document."), + Document(content="This is the text of another document."), + ] + + result = writer.run(documents=documents) + assert result["documents_written"] == 2 + + result = writer.run(documents=documents) + assert result["documents_written"] == 0