diff --git a/haystack/preview/__init__.py b/haystack/preview/__init__.py index 1cdcb4bd5..bb63c405a 100644 --- a/haystack/preview/__init__.py +++ b/haystack/preview/__init__.py @@ -1,3 +1,3 @@ -from canals.component import component, ComponentInput, ComponentOutput +from canals.component import component from haystack.preview.dataclasses import Document from haystack.preview.pipeline import Pipeline, PipelineError, NoSuchStoreError, load_pipelines, save_pipelines diff --git a/haystack/preview/components/audio/whisper_local.py b/haystack/preview/components/audio/whisper_local.py index 5f7ab6a82..14f26797e 100644 --- a/haystack/preview/components/audio/whisper_local.py +++ b/haystack/preview/components/audio/whisper_local.py @@ -2,12 +2,11 @@ from typing import List, Optional, Dict, Any, Union, BinaryIO, Literal, get_args import logging from pathlib import Path -from dataclasses import dataclass import torch import whisper -from haystack.preview import component, ComponentInput, ComponentOutput, Document +from haystack.preview import component, Document logger = logging.getLogger(__name__) @@ -24,15 +23,21 @@ class LocalWhisperTranscriber: [github repo](https://github.com/openai/whisper). """ - @dataclass - class Input(ComponentInput): + class Input: audio_files: List[Path] whisper_params: Optional[Dict[str, Any]] = None - @dataclass - class Output(ComponentOutput): + class Output: documents: List[Document] + @component.input + def input(self): # type: ignore + return LocalWhisperTranscriber.Input + + @component.output + def output(self): # type: ignore + return LocalWhisperTranscriber.Output + def __init__(self, model_name_or_path: WhisperLocalModel = "large", device: Optional[str] = None): """ :param model_name_or_path: Name of the model to use. Set it to one of the following values: @@ -76,7 +81,7 @@ class LocalWhisperTranscriber: if not data.whisper_params: data.whisper_params = {} documents = self.transcribe(data.audio_files, **data.whisper_params) - return LocalWhisperTranscriber.Output(documents) + return self.output(documents=documents) def transcribe(self, audio_files: Sequence[Union[str, Path, BinaryIO]], **kwargs) -> List[Document]: """ diff --git a/haystack/preview/components/audio/whisper_remote.py b/haystack/preview/components/audio/whisper_remote.py index b7fcb06d7..d3925b970 100644 --- a/haystack/preview/components/audio/whisper_remote.py +++ b/haystack/preview/components/audio/whisper_remote.py @@ -4,10 +4,9 @@ import os import json import logging from pathlib import Path -from dataclasses import dataclass from haystack.utils import request_with_retry -from haystack.preview import component, ComponentInput, ComponentOutput, Document +from haystack.preview import component, Document logger = logging.getLogger(__name__) @@ -29,15 +28,21 @@ class RemoteWhisperTranscriber: [Whisper API documentation](https://platform.openai.com/docs/guides/speech-to-text) """ - @dataclass - class Input(ComponentInput): + class Input: audio_files: List[Path] whisper_params: Optional[Dict[str, Any]] = None - @dataclass - class Output(ComponentOutput): + class Output: documents: List[Document] + @component.input + def input(self): # type: ignore + return RemoteWhisperTranscriber.Input + + @component.output + def output(self): # type: ignore + return RemoteWhisperTranscriber.Output + def __init__( self, api_key: str, model_name: WhisperRemoteModel = "whisper-1", api_base: str = "https://api.openai.com/v1" ): @@ -77,7 +82,7 @@ class RemoteWhisperTranscriber: if not data.whisper_params: data.whisper_params = {} documents = self.transcribe(data.audio_files, **data.whisper_params) - return RemoteWhisperTranscriber.Output(documents) + return self.output(documents=documents) def transcribe(self, audio_files: Sequence[Union[str, Path, BinaryIO]], **kwargs) -> List[Document]: """ diff --git a/haystack/preview/components/retrievers/memory.py b/haystack/preview/components/retrievers/memory.py index d6eba3ac9..1e3171b53 100644 --- a/haystack/preview/components/retrievers/memory.py +++ b/haystack/preview/components/retrievers/memory.py @@ -1,7 +1,6 @@ -from dataclasses import dataclass from typing import Dict, List, Any, Optional -from haystack.preview import component, Document, ComponentInput, ComponentOutput +from haystack.preview import component, Document from haystack.preview.document_stores import MemoryDocumentStore @@ -11,8 +10,7 @@ class MemoryRetriever: A component for retrieving documents from a MemoryDocumentStore using the BM25 algorithm. """ - @dataclass - class Input(ComponentInput): + class Input: """ Input data for the MemoryRetriever component. @@ -29,8 +27,7 @@ class MemoryRetriever: scale_score: bool stores: Dict[str, Any] - @dataclass - class Output(ComponentOutput): + class Output: """ Output data from the MemoryRetriever component. @@ -39,6 +36,14 @@ class MemoryRetriever: documents: List[List[Document]] + @component.input + def input(self): # type: ignore + return MemoryRetriever.Input + + @component.output + def output(self): # type: ignore + return MemoryRetriever.Output + def __init__( self, document_store_name: str, @@ -86,4 +91,4 @@ class MemoryRetriever: query=query, filters=data.filters, top_k=data.top_k, scale_score=data.scale_score ) ) - return MemoryRetriever.Output(documents=docs) + return self.output(documents=docs) diff --git a/haystack/preview/pipeline.py b/haystack/preview/pipeline.py index ca105096b..c2cd75e33 100644 --- a/haystack/preview/pipeline.py +++ b/haystack/preview/pipeline.py @@ -2,7 +2,6 @@ from typing import List, Dict, Any, Optional, Callable from pathlib import Path -from canals.component import ComponentInput, ComponentOutput from canals.pipeline import ( Pipeline as CanalsPipeline, PipelineError, @@ -57,7 +56,7 @@ class Pipeline(CanalsPipeline): except KeyError as e: raise NoSuchStoreError(f"No store named '{name}' is connected to this pipeline.") from e - def run(self, data: Dict[str, ComponentInput], debug: bool = False) -> Dict[str, ComponentOutput]: + def run(self, data: Dict[str, Any], debug: bool = False) -> Dict[str, Any]: """ Wrapper on top of Canals Pipeline.run(). Adds the `stores` parameter to all nodes. diff --git a/pyproject.toml b/pyproject.toml index a085696a6..a9ca80f62 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -79,7 +79,7 @@ dependencies = [ "jsonschema", # Preview - "canals==0.2.2", + "canals>=0.3,<0.4", # Agent events "events", diff --git a/test/preview/pipeline/test_pipeline.py b/test/preview/pipeline/test_pipeline.py index dfe437bd8..2a4be6291 100644 --- a/test/preview/pipeline/test_pipeline.py +++ b/test/preview/pipeline/test_pipeline.py @@ -1,9 +1,8 @@ from typing import Dict, Any -from dataclasses import dataclass import pytest -from haystack.preview import Pipeline, component, NoSuchStoreError, ComponentInput, ComponentOutput +from haystack.preview import Pipeline, component, NoSuchStoreError class MockStore: @@ -34,15 +33,21 @@ def test_pipeline_stores_in_params(): @component class MockComponent: - @dataclass - class Input(ComponentInput): + class Input: value: int stores: Dict[str, Any] - @dataclass - class Output(ComponentOutput): + class Output: value: int + @component.input + def input(self): + return MockComponent.Input + + @component.output + def output(self): + return MockComponent.Output + def run(self, data: Input) -> Output: assert data.stores == {"first_store": store_1, "second_store": store_2} return MockComponent.Output(value=data.value)