Fix: prevent in-place mutation of documents in Document Classifiers and Extractors (#9703)

* Modify Document Classifiers and Extractors to not make in-place changes

* Add e2e test for NER

* Add unit test for NER

* fixes + refinements

---------

Co-authored-by: anakin87 <stefanofiorucci@gmail.com>
Abdelrahman Kaseb 2025-08-12 16:20:44 +03:00 committed by GitHub
parent f8d3a82997
commit b9a34dfebf
11 changed files with 102 additions and 22 deletions
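
Every file in this commit applies the same pattern: instead of assigning into `document.meta` in place, each component builds a fresh meta dict and wraps it in a new `Document` via `dataclasses.replace`. A minimal sketch of the idea; the `add_language` helper is illustrative only, not part of the Haystack API:

from dataclasses import replace

from haystack import Document


def add_language(documents: list[Document], language: str) -> list[Document]:
    # Hypothetical helper: shows the copy-instead-of-mutate pattern
    # that the diffs below apply to each component's run() method.
    new_documents = []
    for document in documents:
        # Copy the meta dict so the caller's Document stays untouched.
        new_meta = {**document.meta, "language": language}
        # dataclasses.replace builds a new Document sharing all other fields.
        new_documents.append(replace(document, meta=new_meta))
    return new_documents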

View File

@@ -2,6 +2,10 @@
 #
 # SPDX-License-Identifier: Apache-2.0
+# Note: We only test the Spacy backend in this module, which is executed in the e2e environment.
+# We don't test Spacy in test/components/extractors/test_named_entity_extractor.py, which is executed in the
+# test environment. Spacy is not installed in the test environment to keep the CI fast.
+
 import os

 import pytest
@@ -113,7 +117,15 @@ def test_ner_extractor_in_pipeline(raw_texts, hf_annotations, batch_size, monkeypatch):
 def _extract_and_check_predictions(extractor, texts, expected, batch_size):
     docs = [Document(content=text) for text in texts]
     outputs = extractor.run(documents=docs, batch_size=batch_size)["documents"]
-    assert all(id(a) == id(b) for a, b in zip(docs, outputs))
+    for original_doc, output_doc in zip(docs, outputs):
+        # we don't modify documents in place
+        assert original_doc is not output_doc
+        # apart from meta, the documents should be identical
+        output_doc_dict = output_doc.to_dict(flatten=False)
+        output_doc_dict.pop("meta", None)
+        assert original_doc.to_dict() == output_doc_dict
     predicted = [NamedEntityExtractor.get_stored_annotations(doc) for doc in outputs]
     _check_predictions(predicted, expected)
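
The new assertion leans on `Document.to_dict`: with the default `flatten=True` the meta keys are merged into the top-level dict, while `flatten=False` keeps them nested under a `"meta"` key, so popping `"meta"` leaves only the core fields for comparison against the untouched original. A tiny illustration with made-up values:

from haystack import Document

doc = Document(content="hi", meta={"language": "en"})

# flatten=True (the default) merges meta keys into the top level ...
assert doc.to_dict()["language"] == "en"
# ... while flatten=False keeps them nested under "meta".
assert doc.to_dict(flatten=False)["meta"] == {"language": "en"}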

View File

@@ -2,6 +2,7 @@
 #
 # SPDX-License-Identifier: Apache-2.0
+from dataclasses import replace
 from typing import Optional

 from haystack import Document, component, logging
@@ -89,14 +90,17 @@ class DocumentLanguageClassifier:
         output: dict[str, list[Document]] = {language: [] for language in self.languages}
         output["unmatched"] = []

+        new_documents = []
         for document in documents:
             detected_language = self._detect_language(document)
+            new_meta = {**document.meta}
             if detected_language in self.languages:
-                document.meta["language"] = detected_language
+                new_meta["language"] = detected_language
             else:
-                document.meta["language"] = "unmatched"
+                new_meta["language"] = "unmatched"
+            new_documents.append(replace(document, meta=new_meta))

-        return {"documents": documents}
+        return {"documents": new_documents}

     def _detect_language(self, document: Document) -> Optional[str]:
         language = None
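
One property of `dataclasses.replace` worth noting here: it makes a shallow copy, so every field that is not overridden (`id`, `content`, `embedding`, and so on) is carried over from the original `Document`; only `meta` points at the freshly built dict. The LLMMetadataExtractor tests further down rely on exactly this (`failed_doc_none.id == doc_with_none_content.id`). A throwaway example, not taken from the commit:

from dataclasses import replace

from haystack import Document

doc = Document(content="Hallo Welt")
copy_doc = replace(doc, meta={**doc.meta, "language": "de"})

assert copy_doc is not doc             # a distinct Document object
assert copy_doc.id == doc.id           # id and content carry over unchanged
assert copy_doc.content == doc.content
assert "language" not in doc.meta      # the original meta is untouched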

View File

@@ -2,6 +2,7 @@
 #
 # SPDX-License-Identifier: Apache-2.0
+from dataclasses import replace
 from typing import Any, Optional

 from haystack import Document, component, default_from_dict, default_to_dict
@@ -232,12 +233,14 @@ class TransformersZeroShotDocumentClassifier:
         predictions = self.pipeline(texts, self.labels, multi_label=self.multi_label, batch_size=batch_size)

+        new_documents = []
         for prediction, document in zip(predictions, documents):
             formatted_prediction = {
                 "label": prediction["labels"][0],
                 "score": prediction["scores"][0],
                 "details": dict(zip(prediction["labels"], prediction["scores"])),
             }
-            document.meta["classification"] = formatted_prediction
+            new_meta = {**document.meta, "classification": formatted_prediction}
+            new_documents.append(replace(document, meta=new_meta))

-        return {"documents": documents}
+        return {"documents": new_documents}

View File

@@ -5,6 +5,7 @@
 import copy
 import json
 from concurrent.futures import ThreadPoolExecutor
+from dataclasses import replace
 from typing import Any, Optional, Union

 from jinja2 import meta
@@ -318,24 +319,25 @@ class LLMMetadataExtractor:
         successful_documents = []
         failed_documents = []
         for document, result in zip(documents, results):
+            new_meta = {**document.meta}
             if "error" in result:
-                document.meta["metadata_extraction_error"] = result["error"]
-                document.meta["metadata_extraction_response"] = None
-                failed_documents.append(document)
+                new_meta["metadata_extraction_error"] = result["error"]
+                new_meta["metadata_extraction_response"] = None
+                failed_documents.append(replace(document, meta=new_meta))
                 continue

             parsed_metadata = self._extract_metadata(result["replies"][0].text)
             if "error" in parsed_metadata:
-                document.meta["metadata_extraction_error"] = parsed_metadata["error"]
-                document.meta["metadata_extraction_response"] = result["replies"][0]
-                failed_documents.append(document)
+                new_meta["metadata_extraction_error"] = parsed_metadata["error"]
+                new_meta["metadata_extraction_response"] = result["replies"][0]
+                failed_documents.append(replace(document, meta=new_meta))
                 continue

             for key in parsed_metadata:
-                document.meta[key] = parsed_metadata[key]
+                new_meta[key] = parsed_metadata[key]

             # Remove metadata_extraction_error and metadata_extraction_response if present from previous runs
-            document.meta.pop("metadata_extraction_error", None)
-            document.meta.pop("metadata_extraction_response", None)
+            new_meta.pop("metadata_extraction_error", None)
+            new_meta.pop("metadata_extraction_response", None)

-            successful_documents.append(document)
+            successful_documents.append(replace(document, meta=new_meta))

         return {"documents": successful_documents, "failed_documents": failed_documents}

View File

@@ -4,7 +4,7 @@
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
-from dataclasses import dataclass
+from dataclasses import dataclass, replace
 from enum import Enum
 from typing import Any, Optional, Union
@@ -204,10 +204,12 @@ class NamedEntityExtractor:
                 f"got {len(annotations)} but expected {len(documents)}"
             )

+        new_documents = []
         for doc, doc_annotations in zip(documents, annotations):
-            doc.meta[self._METADATA_KEY] = doc_annotations
+            new_meta = {**doc.meta, self._METADATA_KEY: doc_annotations}
+            new_documents.append(replace(doc, meta=new_meta))

-        return {"documents": documents}
+        return {"documents": new_documents}

     def to_dict(self) -> dict[str, Any]:
         """

View File

@@ -173,7 +173,7 @@ extra-dependencies = [
 ]

 [tool.hatch.envs.e2e.scripts]
-test = "pytest e2e"
+test = "pytest {args:e2e}"

 [tool.hatch.envs.readme]
 installer = "uv"

View File

@@ -0,0 +1,5 @@
+---
+fixes:
+  - |
+    Prevented in-place mutation of input `Document` objects in all `Extractor` and `Classifier` components
+    by creating copies with `dataclasses.replace` before processing.
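
In user terms, the release note means a pipeline's input documents can be safely reused after classification or extraction. A sketch of the observable behavior, assuming the langdetect dependency of `DocumentLanguageClassifier` is installed and detection succeeds on the sample text:

from haystack import Document
from haystack.components.classifiers import DocumentLanguageClassifier

docs = [Document(content="This is an English sentence used for language detection.")]
classifier = DocumentLanguageClassifier(languages=["en", "de"])
result = classifier.run(documents=docs)

# The detected language lands only on the returned copies ...
assert result["documents"][0].meta["language"] == "en"
# ... while the caller's original documents keep their meta unchanged.
assert "language" not in docs[0].meta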

View File

@@ -42,6 +42,8 @@ class TestDocumentLanguageClassifier:
         result = classifier.run(documents=[english_document, german_document])
         assert result["documents"][0].meta["language"] == "en"
         assert result["documents"][1].meta["language"] == "unmatched"
+        assert "language" not in english_document.meta
+        assert "language" not in german_document.meta

     def test_warning_if_no_language_detected(self, caplog):
         with caplog.at_level(logging.WARNING):

View File

@@ -135,6 +135,8 @@ class TestTransformersZeroShotDocumentClassifier:
         assert component.pipeline is not None
         assert result["documents"][0].to_dict()["classification"]["label"] == "positive"
         assert result["documents"][1].to_dict()["classification"]["label"] == "negative"
+        assert "classification" not in positive_document.to_dict()
+        assert "classification" not in negative_document.to_dict()

     @pytest.mark.integration
     @pytest.mark.slow
@@ -150,6 +152,8 @@ class TestTransformersZeroShotDocumentClassifier:
         assert component.pipeline is not None
         assert result["documents"][0].to_dict()["classification"]["label"] == "positive"
         assert result["documents"][1].to_dict()["classification"]["label"] == "negative"
+        assert "classification" not in positive_document.to_dict()
+        assert "classification" not in negative_document.to_dict()

     def test_serialization_and_deserialization_pipeline(self):
         pipeline = Pipeline()

View File

@@ -251,11 +251,13 @@ class TestLLMMetadataExtractor:
         assert failed_doc_none.id == doc_with_none_content.id
         assert "metadata_extraction_error" in failed_doc_none.meta
         assert failed_doc_none.meta["metadata_extraction_error"] == "Document has no content, skipping LLM call."
+        assert "metadata_extraction_error" not in doc_with_none_content.meta

         failed_doc_empty = result["failed_documents"][1]
         assert failed_doc_empty.id == doc_with_empty_content.id
         assert "metadata_extraction_error" in failed_doc_empty.meta
         assert failed_doc_empty.meta["metadata_extraction_error"] == "Document has no content, skipping LLM call."
+        assert "metadata_extraction_error" not in doc_with_empty_content.meta

         # Ensure no attempt was made to call the LLM
         mock_chat_generator.run.assert_not_called()

View File

@@ -6,10 +6,12 @@
 # Spacy is not installed in the test environment to keep the CI fast.
 # We test the Spacy backend in e2e/pipelines/test_named_entity_extractor.py.

+from unittest.mock import patch
+
 import pytest

-from haystack import ComponentError, DeserializationError, Pipeline
+from haystack import ComponentError, DeserializationError, Document, Pipeline
-from haystack.components.extractors import NamedEntityExtractor, NamedEntityExtractorBackend
+from haystack.components.extractors import NamedEntityAnnotation, NamedEntityExtractor, NamedEntityExtractorBackend
 from haystack.utils.auth import Secret
 from haystack.utils.device import ComponentDevice
@@ -132,3 +134,45 @@ def test_named_entity_extractor_serde_none_device():
     assert type(new_extractor._backend) == type(extractor._backend)
     assert new_extractor._backend.model_name == extractor._backend.model_name
     assert new_extractor._backend.device == extractor._backend.device
+
+
+def test_named_entity_extractor_run():
+    """Test the NamedEntityExtractor.run method with mocked model interaction."""
+    documents = [Document(content="My name is Clara and I live in Berkeley, California.")]
+    expected_annotations = [
+        [
+            NamedEntityAnnotation(entity="PER", start=11, end=16, score=0.95),
+            NamedEntityAnnotation(entity="LOC", start=31, end=39, score=0.88),
+            NamedEntityAnnotation(entity="LOC", start=41, end=51, score=0.92),
+        ]
+    ]
+
+    extractor = NamedEntityExtractor(backend=NamedEntityExtractorBackend.HUGGING_FACE, model="dslim/bert-base-NER")
+
+    with patch.object(extractor._backend, "annotate", return_value=expected_annotations) as mock_annotate:
+        extractor._backend.pipeline = "mocked_pipeline"
+        extractor._warmed_up = True
+        result = extractor.run(documents=documents, batch_size=2)
+
+    mock_annotate.assert_called_once_with(["My name is Clara and I live in Berkeley, California."], batch_size=2)
+
+    assert "documents" in result
+    assert len(result["documents"]) == 1
+    assert isinstance(result["documents"][0], Document)
+    assert result["documents"][0].content == documents[0].content
+    assert "named_entities" in result["documents"][0].meta
+    assert result["documents"][0].meta["named_entities"] == expected_annotations[0]
+    assert "named_entities" not in documents[0].meta
+
+
+def test_named_entity_extractor_run_not_warmed_up():
+    """Test that run method raises error when not warmed up."""
+    extractor = NamedEntityExtractor(backend=NamedEntityExtractorBackend.HUGGING_FACE, model="dslim/bert-base-NER")
+    documents = [Document(content="Test document")]

+    with pytest.raises(RuntimeError, match="The component NamedEntityExtractor was not warmed up"):
+        extractor.run(documents=documents)
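
Because the extractor now returns copies, annotations must be read off the output documents, as the e2e test does via the `get_stored_annotations` helper. A usage sketch under that assumption; the model name is reused from the tests, and running this for real downloads the Hugging Face model:

from haystack import Document
from haystack.components.extractors import NamedEntityExtractor, NamedEntityExtractorBackend

extractor = NamedEntityExtractor(backend=NamedEntityExtractorBackend.HUGGING_FACE, model="dslim/bert-base-NER")
extractor.warm_up()  # loads the model

docs = [Document(content="My name is Clara and I live in Berkeley, California.")]
outputs = extractor.run(documents=docs)["documents"]

# Read annotations from the returned copies, not from the inputs.
annotations = NamedEntityExtractor.get_stored_annotations(outputs[0])
assert "named_entities" not in docs[0].meta  # inputs stay untouched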