Fix: prevent in-place mutation of documents in Document Classifiers and Extractors (#9703)

* modify Document Classifiers and Extractors to not make in-place changes

* Add e2e test for NER

* Add unit test for NER

* fixes + refinements

---------

Co-authored-by: anakin87 <stefanofiorucci@gmail.com>
This commit is contained in:
Abdelrahman Kaseb 2025-08-12 16:20:44 +03:00 committed by GitHub
parent f8d3a82997
commit b9a34dfebf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 102 additions and 22 deletions

View File

@ -2,6 +2,10 @@
#
# SPDX-License-Identifier: Apache-2.0
# Note: We only test the Spacy backend in this module, which is executed in the e2e environment.
# We don't test Spacy in test/components/extractors/test_named_entity_extractor.py, which is executed in the
# test environment. Spacy is not installed in the test environment to keep the CI fast.
import os
import pytest
@ -113,7 +117,15 @@ def test_ner_extractor_in_pipeline(raw_texts, hf_annotations, batch_size, monkey
def _extract_and_check_predictions(extractor, texts, expected, batch_size):
    """
    Run `extractor` on `texts` and verify both the returned documents and their annotations.

    :param extractor: Component whose `run(documents=..., batch_size=...)` returns a dict
        containing a "documents" list.
    :param texts: Raw strings; one `Document` is created per string.
    :param expected: Expected annotations, forwarded to `_check_predictions`.
    :param batch_size: Forwarded unchanged to `extractor.run`.
    """
    docs = [Document(content=text) for text in texts]
    outputs = extractor.run(documents=docs, batch_size=batch_size)["documents"]

    # zip() stops at the shorter sequence, so make sure no document was dropped
    assert len(outputs) == len(docs)

    for original_doc, output_doc in zip(docs, outputs):
        # we don't modify documents in place
        assert original_doc is not output_doc
        # apart from meta, the documents should be identical
        output_doc_dict = output_doc.to_dict(flatten=False)
        output_doc_dict.pop("meta", None)
        assert original_doc.to_dict() == output_doc_dict

    predicted = [NamedEntityExtractor.get_stored_annotations(doc) for doc in outputs]
    _check_predictions(predicted, expected)

View File

@ -2,6 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0
from dataclasses import replace
from typing import Optional
from haystack import Document, component, logging
@ -89,14 +90,17 @@ class DocumentLanguageClassifier:
output: dict[str, list[Document]] = {language: [] for language in self.languages}
output["unmatched"] = []
new_documents = []
for document in documents:
detected_language = self._detect_language(document)
new_meta = {**document.meta}
if detected_language in self.languages:
document.meta["language"] = detected_language
new_meta["language"] = detected_language
else:
document.meta["language"] = "unmatched"
new_meta["language"] = "unmatched"
new_documents.append(replace(document, meta=new_meta))
return {"documents": documents}
return {"documents": new_documents}
def _detect_language(self, document: Document) -> Optional[str]:
language = None

View File

@ -2,6 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0
from dataclasses import replace
from typing import Any, Optional
from haystack import Document, component, default_from_dict, default_to_dict
@ -232,12 +233,14 @@ class TransformersZeroShotDocumentClassifier:
predictions = self.pipeline(texts, self.labels, multi_label=self.multi_label, batch_size=batch_size)
new_documents = []
for prediction, document in zip(predictions, documents):
formatted_prediction = {
"label": prediction["labels"][0],
"score": prediction["scores"][0],
"details": dict(zip(prediction["labels"], prediction["scores"])),
}
document.meta["classification"] = formatted_prediction
new_meta = {**document.meta, "classification": formatted_prediction}
new_documents.append(replace(document, meta=new_meta))
return {"documents": documents}
return {"documents": new_documents}

View File

@ -5,6 +5,7 @@
import copy
import json
from concurrent.futures import ThreadPoolExecutor
from dataclasses import replace
from typing import Any, Optional, Union
from jinja2 import meta
@ -318,24 +319,25 @@ class LLMMetadataExtractor:
successful_documents = []
failed_documents = []
for document, result in zip(documents, results):
new_meta = {**document.meta}
if "error" in result:
document.meta["metadata_extraction_error"] = result["error"]
document.meta["metadata_extraction_response"] = None
failed_documents.append(document)
new_meta["metadata_extraction_error"] = result["error"]
new_meta["metadata_extraction_response"] = None
failed_documents.append(replace(document, meta=new_meta))
continue
parsed_metadata = self._extract_metadata(result["replies"][0].text)
if "error" in parsed_metadata:
document.meta["metadata_extraction_error"] = parsed_metadata["error"]
document.meta["metadata_extraction_response"] = result["replies"][0]
failed_documents.append(document)
new_meta["metadata_extraction_error"] = parsed_metadata["error"]
new_meta["metadata_extraction_response"] = result["replies"][0]
failed_documents.append(replace(document, meta=new_meta))
continue
for key in parsed_metadata:
document.meta[key] = parsed_metadata[key]
new_meta[key] = parsed_metadata[key]
# Remove metadata_extraction_error and metadata_extraction_response if present from previous runs
document.meta.pop("metadata_extraction_error", None)
document.meta.pop("metadata_extraction_response", None)
successful_documents.append(document)
new_meta.pop("metadata_extraction_error", None)
new_meta.pop("metadata_extraction_response", None)
successful_documents.append(replace(document, meta=new_meta))
return {"documents": successful_documents, "failed_documents": failed_documents}

View File

@ -4,7 +4,7 @@
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass
from dataclasses import dataclass, replace
from enum import Enum
from typing import Any, Optional, Union
@ -204,10 +204,12 @@ class NamedEntityExtractor:
f"got {len(annotations)} but expected {len(documents)}"
)
new_documents = []
for doc, doc_annotations in zip(documents, annotations):
doc.meta[self._METADATA_KEY] = doc_annotations
new_meta = {**doc.meta, self._METADATA_KEY: doc_annotations}
new_documents.append(replace(doc, meta=new_meta))
return {"documents": documents}
return {"documents": new_documents}
def to_dict(self) -> dict[str, Any]:
"""

View File

@ -173,7 +173,7 @@ extra-dependencies = [
]
[tool.hatch.envs.e2e.scripts]
test = "pytest e2e"
test = "pytest {args:e2e}"
[tool.hatch.envs.readme]
installer = "uv"

View File

@ -0,0 +1,5 @@
---
fixes:
- |
Prevented in-place mutation of input `Document` objects in all `Extractor` and `Classifier` components.
These components now return new `Document` instances, built with `dataclasses.replace`, that carry the updated metadata.

View File

@ -42,6 +42,8 @@ class TestDocumentLanguageClassifier:
result = classifier.run(documents=[english_document, german_document])
assert result["documents"][0].meta["language"] == "en"
assert result["documents"][1].meta["language"] == "unmatched"
assert "language" not in english_document.meta
assert "language" not in german_document.meta
def test_warning_if_no_language_detected(self, caplog):
with caplog.at_level(logging.WARNING):

View File

@ -135,6 +135,8 @@ class TestTransformersZeroShotDocumentClassifier:
assert component.pipeline is not None
assert result["documents"][0].to_dict()["classification"]["label"] == "positive"
assert result["documents"][1].to_dict()["classification"]["label"] == "negative"
assert "classification" not in positive_document.to_dict()
assert "classification" not in negative_document.to_dict()
@pytest.mark.integration
@pytest.mark.slow
@ -150,6 +152,8 @@ class TestTransformersZeroShotDocumentClassifier:
assert component.pipeline is not None
assert result["documents"][0].to_dict()["classification"]["label"] == "positive"
assert result["documents"][1].to_dict()["classification"]["label"] == "negative"
assert "classification" not in positive_document.to_dict()
assert "classification" not in negative_document.to_dict()
def test_serialization_and_deserialization_pipeline(self):
pipeline = Pipeline()

View File

@ -251,11 +251,13 @@ class TestLLMMetadataExtractor:
assert failed_doc_none.id == doc_with_none_content.id
assert "metadata_extraction_error" in failed_doc_none.meta
assert failed_doc_none.meta["metadata_extraction_error"] == "Document has no content, skipping LLM call."
assert "metadata_extraction_error" not in doc_with_none_content.meta
failed_doc_empty = result["failed_documents"][1]
assert failed_doc_empty.id == doc_with_empty_content.id
assert "metadata_extraction_error" in failed_doc_empty.meta
assert failed_doc_empty.meta["metadata_extraction_error"] == "Document has no content, skipping LLM call."
assert "metadata_extraction_error" not in doc_with_empty_content.meta
# Ensure no attempt was made to call the LLM
mock_chat_generator.run.assert_not_called()

View File

@ -6,10 +6,12 @@
# Spacy is not installed in the test environment to keep the CI fast.
# We test the Spacy backend in e2e/pipelines/test_named_entity_extractor.py.
from unittest.mock import patch
import pytest
from haystack import ComponentError, DeserializationError, Pipeline
from haystack.components.extractors import NamedEntityExtractor, NamedEntityExtractorBackend
from haystack import ComponentError, DeserializationError, Document, Pipeline
from haystack.components.extractors import NamedEntityAnnotation, NamedEntityExtractor, NamedEntityExtractorBackend
from haystack.utils.auth import Secret
from haystack.utils.device import ComponentDevice
@ -132,3 +134,45 @@ def test_named_entity_extractor_serde_none_device():
assert type(new_extractor._backend) == type(extractor._backend)
assert new_extractor._backend.model_name == extractor._backend.model_name
assert new_extractor._backend.device == extractor._backend.device
def test_named_entity_extractor_run():
    """Check that `run` stores the mocked annotations on the output and leaves the input document untouched."""
    text = "My name is Clara and I live in Berkeley, California."
    input_docs = [Document(content=text)]
    mocked_annotations = [
        [
            NamedEntityAnnotation(entity="PER", start=11, end=16, score=0.95),
            NamedEntityAnnotation(entity="LOC", start=31, end=39, score=0.88),
            NamedEntityAnnotation(entity="LOC", start=41, end=51, score=0.92),
        ]
    ]

    ner = NamedEntityExtractor(backend=NamedEntityExtractorBackend.HUGGING_FACE, model="dslim/bert-base-NER")

    with patch.object(ner._backend, "annotate", return_value=mocked_annotations) as annotate_mock:
        # Set a placeholder pipeline and the warmed-up flag by hand instead of loading a model
        ner._backend.pipeline = "mocked_pipeline"
        ner._warmed_up = True

        result = ner.run(documents=input_docs, batch_size=2)

        annotate_mock.assert_called_once_with([text], batch_size=2)

        assert "documents" in result
        output_docs = result["documents"]
        assert len(output_docs) == 1

        output_doc = output_docs[0]
        assert isinstance(output_doc, Document)
        assert output_doc.content == input_docs[0].content
        assert "named_entities" in output_doc.meta
        assert output_doc.meta["named_entities"] == mocked_annotations[0]

        # the input document must not have received the annotations
        assert "named_entities" not in input_docs[0].meta
def test_named_entity_extractor_run_not_warmed_up():
    """Calling `run` on an extractor that was never warmed up must raise a RuntimeError."""
    docs = [Document(content="Test document")]
    ner = NamedEntityExtractor(backend=NamedEntityExtractorBackend.HUGGING_FACE, model="dslim/bert-base-NER")
    expected_message = "The component NamedEntityExtractor was not warmed up"

    with pytest.raises(RuntimeError, match=expected_message):
        ner.run(documents=docs)