fix: Prevent set_output_types from being called when the output_types decorator is used (#8376)

This commit is contained in:
Madeesh Kannan 2024-09-18 13:05:31 +02:00 committed by GitHub
parent 117c298145
commit b22014b915
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 26 additions and 2 deletions

View File

@ -121,7 +121,6 @@ class TransformersZeroShotDocumentClassifier:
self.token = token
self.labels = labels
self.multi_label = multi_label
component.set_output_types(self, **{label: str for label in labels})
huggingface_pipeline_kwargs = resolve_hf_pipeline_kwargs(
huggingface_pipeline_kwargs=huggingface_pipeline_kwargs or {},
@ -229,7 +228,7 @@ class TransformersZeroShotDocumentClassifier:
)
texts = [
doc.content if self.classification_field is None else doc.meta[self.classification_field]
(doc.content if self.classification_field is None else doc.meta[self.classification_field])
for doc in documents
]

View File

@ -449,6 +449,13 @@ class _Component:
return {"output_1": 1, "output_2": "2"}
```
"""
has_decorator = hasattr(instance.run, "_output_types_cache")
if has_decorator:
raise ComponentError(
"Cannot call `set_output_types` on a component that already has "
"the 'output_types' decorator on its `run` method"
)
instance.__haystack_output__ = Sockets(
instance, {name: OutputSocket(name=name, type=type_) for name, type_ in types.items()}, OutputSocket
)

View File

@ -0,0 +1,4 @@
---
fixes:
- |
Prevent `set_output_types`` from being called when the `output_types`` decorator is used.

View File

@ -315,6 +315,20 @@ def test_output_types_decorator_wrong_method():
return cls()
def test_output_types_decorator_and_set_output_types():
@component
class MockComponent:
def __init__(self) -> None:
component.set_output_types(self, value=int)
@component.output_types(value=int)
def run(self, value: int):
return {"value": 1}
with pytest.raises(ComponentError, match="Cannot call `set_output_types`"):
comp = MockComponent()
def test_output_types_decorator_mismatch_run_async_run():
@component
class MockComponent: