mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-04 10:58:45 +00:00
Fix: prevent in-place mutation of documents in Document Classifiers and Extractors (#9703)
* modify Documents Classifiers and Extractors to not make in-place changes * Add e2e test for NER * Add unit test for NER * fixes + refinements --------- Co-authored-by: anakin87 <stefanofiorucci@gmail.com>
This commit is contained in:
parent
f8d3a82997
commit
b9a34dfebf
@ -2,6 +2,10 @@
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# Note: We only test the Spacy backend in this module, which is executed in the e2e environment.
|
||||
# We don't test Spacy in test/components/extractors/test_named_entity_extractor.py, which is executed in the
|
||||
# test environment. Spacy is not installed in the test environment to keep the CI fast.
|
||||
|
||||
import os
|
||||
import pytest
|
||||
|
||||
@ -113,7 +117,15 @@ def test_ner_extractor_in_pipeline(raw_texts, hf_annotations, batch_size, monkey
|
||||
def _extract_and_check_predictions(extractor, texts, expected, batch_size):
|
||||
docs = [Document(content=text) for text in texts]
|
||||
outputs = extractor.run(documents=docs, batch_size=batch_size)["documents"]
|
||||
assert all(id(a) == id(b) for a, b in zip(docs, outputs))
|
||||
for original_doc, output_doc in zip(docs, outputs):
|
||||
# we don't modify documents in place
|
||||
assert original_doc is not output_doc
|
||||
|
||||
# apart from meta, the documents should be identical
|
||||
output_doc_dict = output_doc.to_dict(flatten=False)
|
||||
output_doc_dict.pop("meta", None)
|
||||
assert original_doc.to_dict() == output_doc_dict
|
||||
|
||||
predicted = [NamedEntityExtractor.get_stored_annotations(doc) for doc in outputs]
|
||||
|
||||
_check_predictions(predicted, expected)
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from dataclasses import replace
|
||||
from typing import Optional
|
||||
|
||||
from haystack import Document, component, logging
|
||||
@ -89,14 +90,17 @@ class DocumentLanguageClassifier:
|
||||
output: dict[str, list[Document]] = {language: [] for language in self.languages}
|
||||
output["unmatched"] = []
|
||||
|
||||
new_documents = []
|
||||
for document in documents:
|
||||
detected_language = self._detect_language(document)
|
||||
new_meta = {**document.meta}
|
||||
if detected_language in self.languages:
|
||||
document.meta["language"] = detected_language
|
||||
new_meta["language"] = detected_language
|
||||
else:
|
||||
document.meta["language"] = "unmatched"
|
||||
new_meta["language"] = "unmatched"
|
||||
new_documents.append(replace(document, meta=new_meta))
|
||||
|
||||
return {"documents": documents}
|
||||
return {"documents": new_documents}
|
||||
|
||||
def _detect_language(self, document: Document) -> Optional[str]:
|
||||
language = None
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from dataclasses import replace
|
||||
from typing import Any, Optional
|
||||
|
||||
from haystack import Document, component, default_from_dict, default_to_dict
|
||||
@ -232,12 +233,14 @@ class TransformersZeroShotDocumentClassifier:
|
||||
|
||||
predictions = self.pipeline(texts, self.labels, multi_label=self.multi_label, batch_size=batch_size)
|
||||
|
||||
new_documents = []
|
||||
for prediction, document in zip(predictions, documents):
|
||||
formatted_prediction = {
|
||||
"label": prediction["labels"][0],
|
||||
"score": prediction["scores"][0],
|
||||
"details": dict(zip(prediction["labels"], prediction["scores"])),
|
||||
}
|
||||
document.meta["classification"] = formatted_prediction
|
||||
new_meta = {**document.meta, "classification": formatted_prediction}
|
||||
new_documents.append(replace(document, meta=new_meta))
|
||||
|
||||
return {"documents": documents}
|
||||
return {"documents": new_documents}
|
||||
|
||||
@ -5,6 +5,7 @@
|
||||
import copy
|
||||
import json
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import replace
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from jinja2 import meta
|
||||
@ -318,24 +319,25 @@ class LLMMetadataExtractor:
|
||||
successful_documents = []
|
||||
failed_documents = []
|
||||
for document, result in zip(documents, results):
|
||||
new_meta = {**document.meta}
|
||||
if "error" in result:
|
||||
document.meta["metadata_extraction_error"] = result["error"]
|
||||
document.meta["metadata_extraction_response"] = None
|
||||
failed_documents.append(document)
|
||||
new_meta["metadata_extraction_error"] = result["error"]
|
||||
new_meta["metadata_extraction_response"] = None
|
||||
failed_documents.append(replace(document, meta=new_meta))
|
||||
continue
|
||||
|
||||
parsed_metadata = self._extract_metadata(result["replies"][0].text)
|
||||
if "error" in parsed_metadata:
|
||||
document.meta["metadata_extraction_error"] = parsed_metadata["error"]
|
||||
document.meta["metadata_extraction_response"] = result["replies"][0]
|
||||
failed_documents.append(document)
|
||||
new_meta["metadata_extraction_error"] = parsed_metadata["error"]
|
||||
new_meta["metadata_extraction_response"] = result["replies"][0]
|
||||
failed_documents.append(replace(document, meta=new_meta))
|
||||
continue
|
||||
|
||||
for key in parsed_metadata:
|
||||
document.meta[key] = parsed_metadata[key]
|
||||
new_meta[key] = parsed_metadata[key]
|
||||
# Remove metadata_extraction_error and metadata_extraction_response if present from previous runs
|
||||
document.meta.pop("metadata_extraction_error", None)
|
||||
document.meta.pop("metadata_extraction_response", None)
|
||||
successful_documents.append(document)
|
||||
new_meta.pop("metadata_extraction_error", None)
|
||||
new_meta.pop("metadata_extraction_response", None)
|
||||
successful_documents.append(replace(document, meta=new_meta))
|
||||
|
||||
return {"documents": successful_documents, "failed_documents": failed_documents}
|
||||
|
||||
@ -4,7 +4,7 @@
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, replace
|
||||
from enum import Enum
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
@ -204,10 +204,12 @@ class NamedEntityExtractor:
|
||||
f"got {len(annotations)} but expected {len(documents)}"
|
||||
)
|
||||
|
||||
new_documents = []
|
||||
for doc, doc_annotations in zip(documents, annotations):
|
||||
doc.meta[self._METADATA_KEY] = doc_annotations
|
||||
new_meta = {**doc.meta, self._METADATA_KEY: doc_annotations}
|
||||
new_documents.append(replace(doc, meta=new_meta))
|
||||
|
||||
return {"documents": documents}
|
||||
return {"documents": new_documents}
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""
|
||||
|
||||
@ -173,7 +173,7 @@ extra-dependencies = [
|
||||
]
|
||||
|
||||
[tool.hatch.envs.e2e.scripts]
|
||||
test = "pytest e2e"
|
||||
test = "pytest {args:e2e}"
|
||||
|
||||
[tool.hatch.envs.readme]
|
||||
installer = "uv"
|
||||
|
||||
@ -0,0 +1,5 @@
|
||||
---
|
||||
fixes:
|
||||
- |
|
||||
Prevented in-place mutation of input `Document` objects in all `Extractor` and `Classifier` components
|
||||
by creating copies with `dataclasses.replace` before processing.
|
||||
@ -42,6 +42,8 @@ class TestDocumentLanguageClassifier:
|
||||
result = classifier.run(documents=[english_document, german_document])
|
||||
assert result["documents"][0].meta["language"] == "en"
|
||||
assert result["documents"][1].meta["language"] == "unmatched"
|
||||
assert "language" not in english_document.meta
|
||||
assert "language" not in german_document.meta
|
||||
|
||||
def test_warning_if_no_language_detected(self, caplog):
|
||||
with caplog.at_level(logging.WARNING):
|
||||
|
||||
@ -135,6 +135,8 @@ class TestTransformersZeroShotDocumentClassifier:
|
||||
assert component.pipeline is not None
|
||||
assert result["documents"][0].to_dict()["classification"]["label"] == "positive"
|
||||
assert result["documents"][1].to_dict()["classification"]["label"] == "negative"
|
||||
assert "classification" not in positive_document.to_dict()
|
||||
assert "classification" not in negative_document.to_dict()
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.slow
|
||||
@ -150,6 +152,8 @@ class TestTransformersZeroShotDocumentClassifier:
|
||||
assert component.pipeline is not None
|
||||
assert result["documents"][0].to_dict()["classification"]["label"] == "positive"
|
||||
assert result["documents"][1].to_dict()["classification"]["label"] == "negative"
|
||||
assert "classification" not in positive_document.to_dict()
|
||||
assert "classification" not in negative_document.to_dict()
|
||||
|
||||
def test_serialization_and_deserialization_pipeline(self):
|
||||
pipeline = Pipeline()
|
||||
|
||||
@ -251,11 +251,13 @@ class TestLLMMetadataExtractor:
|
||||
assert failed_doc_none.id == doc_with_none_content.id
|
||||
assert "metadata_extraction_error" in failed_doc_none.meta
|
||||
assert failed_doc_none.meta["metadata_extraction_error"] == "Document has no content, skipping LLM call."
|
||||
assert "metadata_extraction_error" not in doc_with_none_content.meta
|
||||
|
||||
failed_doc_empty = result["failed_documents"][1]
|
||||
assert failed_doc_empty.id == doc_with_empty_content.id
|
||||
assert "metadata_extraction_error" in failed_doc_empty.meta
|
||||
assert failed_doc_empty.meta["metadata_extraction_error"] == "Document has no content, skipping LLM call."
|
||||
assert "metadata_extraction_error" not in doc_with_empty_content.meta
|
||||
|
||||
# Ensure no attempt was made to call the LLM
|
||||
mock_chat_generator.run.assert_not_called()
|
||||
|
||||
@ -6,10 +6,12 @@
|
||||
# Spacy is not installed in the test environment to keep the CI fast.
|
||||
# We test the Spacy backend in e2e/pipelines/test_named_entity_extractor.py.
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from haystack import ComponentError, DeserializationError, Pipeline
|
||||
from haystack.components.extractors import NamedEntityExtractor, NamedEntityExtractorBackend
|
||||
from haystack import ComponentError, DeserializationError, Document, Pipeline
|
||||
from haystack.components.extractors import NamedEntityAnnotation, NamedEntityExtractor, NamedEntityExtractorBackend
|
||||
from haystack.utils.auth import Secret
|
||||
from haystack.utils.device import ComponentDevice
|
||||
|
||||
@ -132,3 +134,45 @@ def test_named_entity_extractor_serde_none_device():
|
||||
assert type(new_extractor._backend) == type(extractor._backend)
|
||||
assert new_extractor._backend.model_name == extractor._backend.model_name
|
||||
assert new_extractor._backend.device == extractor._backend.device
|
||||
|
||||
|
||||
def test_named_entity_extractor_run():
|
||||
"""Test the NamedEntityExtractor.run method with mocked model interaction."""
|
||||
documents = [Document(content="My name is Clara and I live in Berkeley, California.")]
|
||||
|
||||
expected_annotations = [
|
||||
[
|
||||
NamedEntityAnnotation(entity="PER", start=11, end=16, score=0.95),
|
||||
NamedEntityAnnotation(entity="LOC", start=31, end=39, score=0.88),
|
||||
NamedEntityAnnotation(entity="LOC", start=41, end=51, score=0.92),
|
||||
]
|
||||
]
|
||||
|
||||
extractor = NamedEntityExtractor(backend=NamedEntityExtractorBackend.HUGGING_FACE, model="dslim/bert-base-NER")
|
||||
|
||||
with patch.object(extractor._backend, "annotate", return_value=expected_annotations) as mock_annotate:
|
||||
extractor._backend.pipeline = "mocked_pipeline"
|
||||
extractor._warmed_up = True
|
||||
|
||||
result = extractor.run(documents=documents, batch_size=2)
|
||||
|
||||
mock_annotate.assert_called_once_with(["My name is Clara and I live in Berkeley, California."], batch_size=2)
|
||||
|
||||
assert "documents" in result
|
||||
assert len(result["documents"]) == 1
|
||||
|
||||
assert isinstance(result["documents"][0], Document)
|
||||
assert result["documents"][0].content == documents[0].content
|
||||
assert "named_entities" in result["documents"][0].meta
|
||||
assert result["documents"][0].meta["named_entities"] == expected_annotations[0]
|
||||
assert "named_entities" not in documents[0].meta
|
||||
|
||||
|
||||
def test_named_entity_extractor_run_not_warmed_up():
|
||||
"""Test that run method raises error when not warmed up."""
|
||||
extractor = NamedEntityExtractor(backend=NamedEntityExtractorBackend.HUGGING_FACE, model="dslim/bert-base-NER")
|
||||
|
||||
documents = [Document(content="Test document")]
|
||||
|
||||
with pytest.raises(RuntimeError, match="The component NamedEntityExtractor was not warmed up"):
|
||||
extractor.run(documents=documents)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user