diff --git a/haystack/components/classifiers/zero_shot_document_classifier.py b/haystack/components/classifiers/zero_shot_document_classifier.py index cff245b35..5aa52fde8 100644 --- a/haystack/components/classifiers/zero_shot_document_classifier.py +++ b/haystack/components/classifiers/zero_shot_document_classifier.py @@ -121,7 +121,6 @@ class TransformersZeroShotDocumentClassifier: self.token = token self.labels = labels self.multi_label = multi_label - component.set_output_types(self, **{label: str for label in labels}) huggingface_pipeline_kwargs = resolve_hf_pipeline_kwargs( huggingface_pipeline_kwargs=huggingface_pipeline_kwargs or {}, @@ -229,7 +228,7 @@ class TransformersZeroShotDocumentClassifier: ) texts = [ - doc.content if self.classification_field is None else doc.meta[self.classification_field] + (doc.content if self.classification_field is None else doc.meta[self.classification_field]) for doc in documents ] diff --git a/haystack/core/component/component.py b/haystack/core/component/component.py index e34efdb28..72c444fdd 100644 --- a/haystack/core/component/component.py +++ b/haystack/core/component/component.py @@ -449,6 +449,13 @@ class _Component: return {"output_1": 1, "output_2": "2"} ``` """ + has_decorator = hasattr(instance.run, "_output_types_cache") + if has_decorator: + raise ComponentError( + "Cannot call `set_output_types` on a component that already has " + "the 'output_types' decorator on its `run` method" + ) + instance.__haystack_output__ = Sockets( instance, {name: OutputSocket(name=name, type=type_) for name, type_ in types.items()}, OutputSocket ) diff --git a/releasenotes/notes/component-set-output-type-override-852a19b3f0621fb0.yaml b/releasenotes/notes/component-set-output-type-override-852a19b3f0621fb0.yaml new file mode 100644 index 000000000..2a06fadde --- /dev/null +++ b/releasenotes/notes/component-set-output-type-override-852a19b3f0621fb0.yaml @@ -0,0 +1,4 @@ +--- +fixes: + - | + Prevent `set_output_types`` from being called when the `output_types`` decorator is used. diff --git a/test/core/component/test_component.py b/test/core/component/test_component.py index 320e2d65c..8b4266dbb 100644 --- a/test/core/component/test_component.py +++ b/test/core/component/test_component.py @@ -315,6 +315,20 @@ def test_output_types_decorator_wrong_method(): return cls() +def test_output_types_decorator_and_set_output_types(): + @component + class MockComponent: + def __init__(self) -> None: + component.set_output_types(self, value=int) + + @component.output_types(value=int) + def run(self, value: int): + return {"value": 1} + + with pytest.raises(ComponentError, match="Cannot call `set_output_types`"): + comp = MockComponent() + + def test_output_types_decorator_mismatch_run_async_run(): @component class MockComponent: