From b9a34dfebf3973d05116e0356748d73c27d83f3c Mon Sep 17 00:00:00 2001 From: Abdelrahman Kaseb Date: Tue, 12 Aug 2025 16:20:44 +0300 Subject: [PATCH] Fix: prevent in-place mutation of documents in Document Classifiers and Extractors (#9703) * modify Documents Classifiers and Extractors to not make in-place changes * Add e2e test for NER * Add unit test for NER * fixes + refinements --------- Co-authored-by: anakin87 --- e2e/pipelines/test_named_entity_extractor.py | 14 +++++- .../document_language_classifier.py | 10 ++-- .../zero_shot_document_classifier.py | 7 ++- .../extractors/llm_metadata_extractor.py | 22 +++++---- .../extractors/named_entity_extractor.py | 8 ++-- pyproject.toml | 2 +- ...actor-inplace-change-8a59fe68a1b87e4b.yaml | 5 ++ .../test_document_language_classifier.py | 2 + .../test_zero_shot_document_classifier.py | 4 ++ .../extractors/test_llm_metadata_extractor.py | 2 + .../extractors/test_named_entity_extractor.py | 48 ++++++++++++++++++- 11 files changed, 102 insertions(+), 22 deletions(-) create mode 100644 releasenotes/notes/fix-classifier-extractor-inplace-change-8a59fe68a1b87e4b.yaml diff --git a/e2e/pipelines/test_named_entity_extractor.py b/e2e/pipelines/test_named_entity_extractor.py index 93388322a..64ef3c564 100644 --- a/e2e/pipelines/test_named_entity_extractor.py +++ b/e2e/pipelines/test_named_entity_extractor.py @@ -2,6 +2,10 @@ # # SPDX-License-Identifier: Apache-2.0 +# Note: We only test the Spacy backend in this module, which is executed in the e2e environment. +# We don't test Spacy in test/components/extractors/test_named_entity_extractor.py, which is executed in the +# test environment. Spacy is not installed in the test environment to keep the CI fast. + import os import pytest @@ -113,7 +117,15 @@ def test_ner_extractor_in_pipeline(raw_texts, hf_annotations, batch_size, monkey def _extract_and_check_predictions(extractor, texts, expected, batch_size): docs = [Document(content=text) for text in texts] outputs = extractor.run(documents=docs, batch_size=batch_size)["documents"] - assert all(id(a) == id(b) for a, b in zip(docs, outputs)) + for original_doc, output_doc in zip(docs, outputs): + # we don't modify documents in place + assert original_doc is not output_doc + + # apart from meta, the documents should be identical + output_doc_dict = output_doc.to_dict(flatten=False) + output_doc_dict.pop("meta", None) + assert original_doc.to_dict() == output_doc_dict + predicted = [NamedEntityExtractor.get_stored_annotations(doc) for doc in outputs] _check_predictions(predicted, expected) diff --git a/haystack/components/classifiers/document_language_classifier.py b/haystack/components/classifiers/document_language_classifier.py index 208065258..dc88c9e0c 100644 --- a/haystack/components/classifiers/document_language_classifier.py +++ b/haystack/components/classifiers/document_language_classifier.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 +from dataclasses import replace from typing import Optional from haystack import Document, component, logging @@ -89,14 +90,17 @@ class DocumentLanguageClassifier: output: dict[str, list[Document]] = {language: [] for language in self.languages} output["unmatched"] = [] + new_documents = [] for document in documents: detected_language = self._detect_language(document) + new_meta = {**document.meta} if detected_language in self.languages: - document.meta["language"] = detected_language + new_meta["language"] = detected_language else: - document.meta["language"] = "unmatched" + new_meta["language"] = "unmatched" + new_documents.append(replace(document, meta=new_meta)) - return {"documents": documents} + return {"documents": new_documents} def _detect_language(self, document: Document) -> Optional[str]: language = None diff --git a/haystack/components/classifiers/zero_shot_document_classifier.py b/haystack/components/classifiers/zero_shot_document_classifier.py index 2be587789..94794e816 100644 --- a/haystack/components/classifiers/zero_shot_document_classifier.py +++ b/haystack/components/classifiers/zero_shot_document_classifier.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 +from dataclasses import replace from typing import Any, Optional from haystack import Document, component, default_from_dict, default_to_dict @@ -232,12 +233,14 @@ class TransformersZeroShotDocumentClassifier: predictions = self.pipeline(texts, self.labels, multi_label=self.multi_label, batch_size=batch_size) + new_documents = [] for prediction, document in zip(predictions, documents): formatted_prediction = { "label": prediction["labels"][0], "score": prediction["scores"][0], "details": dict(zip(prediction["labels"], prediction["scores"])), } - document.meta["classification"] = formatted_prediction + new_meta = {**document.meta, "classification": formatted_prediction} + new_documents.append(replace(document, meta=new_meta)) - return {"documents": documents} + return {"documents": new_documents} diff --git a/haystack/components/extractors/llm_metadata_extractor.py b/haystack/components/extractors/llm_metadata_extractor.py index 40a5773f5..eac70b340 100644 --- a/haystack/components/extractors/llm_metadata_extractor.py +++ b/haystack/components/extractors/llm_metadata_extractor.py @@ -5,6 +5,7 @@ import copy import json from concurrent.futures import ThreadPoolExecutor +from dataclasses import replace from typing import Any, Optional, Union from jinja2 import meta @@ -318,24 +319,25 @@ class LLMMetadataExtractor: successful_documents = [] failed_documents = [] for document, result in zip(documents, results): + new_meta = {**document.meta} if "error" in result: - document.meta["metadata_extraction_error"] = result["error"] - document.meta["metadata_extraction_response"] = None - failed_documents.append(document) + new_meta["metadata_extraction_error"] = result["error"] + new_meta["metadata_extraction_response"] = None + failed_documents.append(replace(document, meta=new_meta)) continue parsed_metadata = self._extract_metadata(result["replies"][0].text) if "error" in parsed_metadata: - document.meta["metadata_extraction_error"] = parsed_metadata["error"] - document.meta["metadata_extraction_response"] = result["replies"][0] - failed_documents.append(document) + new_meta["metadata_extraction_error"] = parsed_metadata["error"] + new_meta["metadata_extraction_response"] = result["replies"][0] + failed_documents.append(replace(document, meta=new_meta)) continue for key in parsed_metadata: - document.meta[key] = parsed_metadata[key] + new_meta[key] = parsed_metadata[key] # Remove metadata_extraction_error and metadata_extraction_response if present from previous runs - document.meta.pop("metadata_extraction_error", None) - document.meta.pop("metadata_extraction_response", None) - successful_documents.append(document) + new_meta.pop("metadata_extraction_error", None) + new_meta.pop("metadata_extraction_response", None) + successful_documents.append(replace(document, meta=new_meta)) return {"documents": successful_documents, "failed_documents": failed_documents} diff --git a/haystack/components/extractors/named_entity_extractor.py b/haystack/components/extractors/named_entity_extractor.py index 3efde7da7..4d45e9b36 100644 --- a/haystack/components/extractors/named_entity_extractor.py +++ b/haystack/components/extractors/named_entity_extractor.py @@ -4,7 +4,7 @@ from abc import ABC, abstractmethod from contextlib import contextmanager -from dataclasses import dataclass +from dataclasses import dataclass, replace from enum import Enum from typing import Any, Optional, Union @@ -204,10 +204,12 @@ class NamedEntityExtractor: f"got {len(annotations)} but expected {len(documents)}" ) + new_documents = [] for doc, doc_annotations in zip(documents, annotations): - doc.meta[self._METADATA_KEY] = doc_annotations + new_meta = {**doc.meta, self._METADATA_KEY: doc_annotations} + new_documents.append(replace(doc, meta=new_meta)) - return {"documents": documents} + return {"documents": new_documents} def to_dict(self) -> dict[str, Any]: """ diff --git a/pyproject.toml b/pyproject.toml index 5c3c21cd4..db3ad3a0f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -173,7 +173,7 @@ extra-dependencies = [ ] [tool.hatch.envs.e2e.scripts] -test = "pytest e2e" +test = "pytest {args:e2e}" [tool.hatch.envs.readme] installer = "uv" diff --git a/releasenotes/notes/fix-classifier-extractor-inplace-change-8a59fe68a1b87e4b.yaml b/releasenotes/notes/fix-classifier-extractor-inplace-change-8a59fe68a1b87e4b.yaml new file mode 100644 index 000000000..4bf57a942 --- /dev/null +++ b/releasenotes/notes/fix-classifier-extractor-inplace-change-8a59fe68a1b87e4b.yaml @@ -0,0 +1,5 @@ +--- +fixes: + - | + Prevented in-place mutation of input `Document` objects in all `Extractor` and `Classifier` components + by creating copies with `dataclasses.replace` before processing. diff --git a/test/components/classifiers/test_document_language_classifier.py b/test/components/classifiers/test_document_language_classifier.py index 5922fdb26..b3f259f34 100644 --- a/test/components/classifiers/test_document_language_classifier.py +++ b/test/components/classifiers/test_document_language_classifier.py @@ -42,6 +42,8 @@ class TestDocumentLanguageClassifier: result = classifier.run(documents=[english_document, german_document]) assert result["documents"][0].meta["language"] == "en" assert result["documents"][1].meta["language"] == "unmatched" + assert "language" not in english_document.meta + assert "language" not in german_document.meta def test_warning_if_no_language_detected(self, caplog): with caplog.at_level(logging.WARNING): diff --git a/test/components/classifiers/test_zero_shot_document_classifier.py b/test/components/classifiers/test_zero_shot_document_classifier.py index 0cf7ad2c7..677ca5169 100644 --- a/test/components/classifiers/test_zero_shot_document_classifier.py +++ b/test/components/classifiers/test_zero_shot_document_classifier.py @@ -135,6 +135,8 @@ class TestTransformersZeroShotDocumentClassifier: assert component.pipeline is not None assert result["documents"][0].to_dict()["classification"]["label"] == "positive" assert result["documents"][1].to_dict()["classification"]["label"] == "negative" + assert "classification" not in positive_document.to_dict() + assert "classification" not in negative_document.to_dict() @pytest.mark.integration @pytest.mark.slow @@ -150,6 +152,8 @@ class TestTransformersZeroShotDocumentClassifier: assert component.pipeline is not None assert result["documents"][0].to_dict()["classification"]["label"] == "positive" assert result["documents"][1].to_dict()["classification"]["label"] == "negative" + assert "classification" not in positive_document.to_dict() + assert "classification" not in negative_document.to_dict() def test_serialization_and_deserialization_pipeline(self): pipeline = Pipeline() diff --git a/test/components/extractors/test_llm_metadata_extractor.py b/test/components/extractors/test_llm_metadata_extractor.py index f79a08c06..943a71e5d 100644 --- a/test/components/extractors/test_llm_metadata_extractor.py +++ b/test/components/extractors/test_llm_metadata_extractor.py @@ -251,11 +251,13 @@ class TestLLMMetadataExtractor: assert failed_doc_none.id == doc_with_none_content.id assert "metadata_extraction_error" in failed_doc_none.meta assert failed_doc_none.meta["metadata_extraction_error"] == "Document has no content, skipping LLM call." + assert "metadata_extraction_error" not in doc_with_none_content.meta failed_doc_empty = result["failed_documents"][1] assert failed_doc_empty.id == doc_with_empty_content.id assert "metadata_extraction_error" in failed_doc_empty.meta assert failed_doc_empty.meta["metadata_extraction_error"] == "Document has no content, skipping LLM call." + assert "metadata_extraction_error" not in doc_with_empty_content.meta # Ensure no attempt was made to call the LLM mock_chat_generator.run.assert_not_called() diff --git a/test/components/extractors/test_named_entity_extractor.py b/test/components/extractors/test_named_entity_extractor.py index 87888615e..06322f1ea 100644 --- a/test/components/extractors/test_named_entity_extractor.py +++ b/test/components/extractors/test_named_entity_extractor.py @@ -6,10 +6,12 @@ # Spacy is not installed in the test environment to keep the CI fast. # We test the Spacy backend in e2e/pipelines/test_named_entity_extractor.py. +from unittest.mock import patch + import pytest -from haystack import ComponentError, DeserializationError, Pipeline -from haystack.components.extractors import NamedEntityExtractor, NamedEntityExtractorBackend +from haystack import ComponentError, DeserializationError, Document, Pipeline +from haystack.components.extractors import NamedEntityAnnotation, NamedEntityExtractor, NamedEntityExtractorBackend from haystack.utils.auth import Secret from haystack.utils.device import ComponentDevice @@ -132,3 +134,45 @@ def test_named_entity_extractor_serde_none_device(): assert type(new_extractor._backend) == type(extractor._backend) assert new_extractor._backend.model_name == extractor._backend.model_name assert new_extractor._backend.device == extractor._backend.device + + +def test_named_entity_extractor_run(): + """Test the NamedEntityExtractor.run method with mocked model interaction.""" + documents = [Document(content="My name is Clara and I live in Berkeley, California.")] + + expected_annotations = [ + [ + NamedEntityAnnotation(entity="PER", start=11, end=16, score=0.95), + NamedEntityAnnotation(entity="LOC", start=31, end=39, score=0.88), + NamedEntityAnnotation(entity="LOC", start=41, end=51, score=0.92), + ] + ] + + extractor = NamedEntityExtractor(backend=NamedEntityExtractorBackend.HUGGING_FACE, model="dslim/bert-base-NER") + + with patch.object(extractor._backend, "annotate", return_value=expected_annotations) as mock_annotate: + extractor._backend.pipeline = "mocked_pipeline" + extractor._warmed_up = True + + result = extractor.run(documents=documents, batch_size=2) + + mock_annotate.assert_called_once_with(["My name is Clara and I live in Berkeley, California."], batch_size=2) + + assert "documents" in result + assert len(result["documents"]) == 1 + + assert isinstance(result["documents"][0], Document) + assert result["documents"][0].content == documents[0].content + assert "named_entities" in result["documents"][0].meta + assert result["documents"][0].meta["named_entities"] == expected_annotations[0] + assert "named_entities" not in documents[0].meta + + +def test_named_entity_extractor_run_not_warmed_up(): + """Test that run method raises error when not warmed up.""" + extractor = NamedEntityExtractor(backend=NamedEntityExtractorBackend.HUGGING_FACE, model="dslim/bert-base-NER") + + documents = [Document(content="Test document")] + + with pytest.raises(RuntimeError, match="The component NamedEntityExtractor was not warmed up"): + extractor.run(documents=documents)