feat: upgrade canals in preview (#5344)

* upgrade nodes

* linting
This commit is contained in:
ZanSara 2023-07-13 12:30:49 +02:00 committed by GitHub
parent 8750d92763
commit 7848f00d01
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 50 additions and 31 deletions

View File

@ -1,3 +1,3 @@
from canals.component import component, ComponentInput, ComponentOutput
from canals.component import component
from haystack.preview.dataclasses import Document
from haystack.preview.pipeline import Pipeline, PipelineError, NoSuchStoreError, load_pipelines, save_pipelines

View File

@ -2,12 +2,11 @@ from typing import List, Optional, Dict, Any, Union, BinaryIO, Literal, get_args
import logging
from pathlib import Path
from dataclasses import dataclass
import torch
import whisper
from haystack.preview import component, ComponentInput, ComponentOutput, Document
from haystack.preview import component, Document
logger = logging.getLogger(__name__)
@ -24,15 +23,21 @@ class LocalWhisperTranscriber:
[github repo](https://github.com/openai/whisper).
"""
@dataclass
class Input(ComponentInput):
class Input:
audio_files: List[Path]
whisper_params: Optional[Dict[str, Any]] = None
@dataclass
class Output(ComponentOutput):
class Output:
documents: List[Document]
@component.input
def input(self): # type: ignore
return LocalWhisperTranscriber.Input
@component.output
def output(self): # type: ignore
return LocalWhisperTranscriber.Output
def __init__(self, model_name_or_path: WhisperLocalModel = "large", device: Optional[str] = None):
"""
:param model_name_or_path: Name of the model to use. Set it to one of the following values:
@ -76,7 +81,7 @@ class LocalWhisperTranscriber:
if not data.whisper_params:
data.whisper_params = {}
documents = self.transcribe(data.audio_files, **data.whisper_params)
return LocalWhisperTranscriber.Output(documents)
return self.output(documents=documents)
def transcribe(self, audio_files: Sequence[Union[str, Path, BinaryIO]], **kwargs) -> List[Document]:
"""

View File

@ -4,10 +4,9 @@ import os
import json
import logging
from pathlib import Path
from dataclasses import dataclass
from haystack.utils import request_with_retry
from haystack.preview import component, ComponentInput, ComponentOutput, Document
from haystack.preview import component, Document
logger = logging.getLogger(__name__)
@ -29,15 +28,21 @@ class RemoteWhisperTranscriber:
[Whisper API documentation](https://platform.openai.com/docs/guides/speech-to-text)
"""
@dataclass
class Input(ComponentInput):
class Input:
audio_files: List[Path]
whisper_params: Optional[Dict[str, Any]] = None
@dataclass
class Output(ComponentOutput):
class Output:
documents: List[Document]
@component.input
def input(self): # type: ignore
return RemoteWhisperTranscriber.Input
@component.output
def output(self): # type: ignore
return RemoteWhisperTranscriber.Output
def __init__(
self, api_key: str, model_name: WhisperRemoteModel = "whisper-1", api_base: str = "https://api.openai.com/v1"
):
@ -77,7 +82,7 @@ class RemoteWhisperTranscriber:
if not data.whisper_params:
data.whisper_params = {}
documents = self.transcribe(data.audio_files, **data.whisper_params)
return RemoteWhisperTranscriber.Output(documents)
return self.output(documents=documents)
def transcribe(self, audio_files: Sequence[Union[str, Path, BinaryIO]], **kwargs) -> List[Document]:
"""

View File

@ -1,7 +1,6 @@
from dataclasses import dataclass
from typing import Dict, List, Any, Optional
from haystack.preview import component, Document, ComponentInput, ComponentOutput
from haystack.preview import component, Document
from haystack.preview.document_stores import MemoryDocumentStore
@ -11,8 +10,7 @@ class MemoryRetriever:
A component for retrieving documents from a MemoryDocumentStore using the BM25 algorithm.
"""
@dataclass
class Input(ComponentInput):
class Input:
"""
Input data for the MemoryRetriever component.
@ -29,8 +27,7 @@ class MemoryRetriever:
scale_score: bool
stores: Dict[str, Any]
@dataclass
class Output(ComponentOutput):
class Output:
"""
Output data from the MemoryRetriever component.
@ -39,6 +36,14 @@ class MemoryRetriever:
documents: List[List[Document]]
@component.input
def input(self): # type: ignore
return MemoryRetriever.Input
@component.output
def output(self): # type: ignore
return MemoryRetriever.Output
def __init__(
self,
document_store_name: str,
@ -86,4 +91,4 @@ class MemoryRetriever:
query=query, filters=data.filters, top_k=data.top_k, scale_score=data.scale_score
)
)
return MemoryRetriever.Output(documents=docs)
return self.output(documents=docs)

View File

@ -2,7 +2,6 @@ from typing import List, Dict, Any, Optional, Callable
from pathlib import Path
from canals.component import ComponentInput, ComponentOutput
from canals.pipeline import (
Pipeline as CanalsPipeline,
PipelineError,
@ -57,7 +56,7 @@ class Pipeline(CanalsPipeline):
except KeyError as e:
raise NoSuchStoreError(f"No store named '{name}' is connected to this pipeline.") from e
def run(self, data: Dict[str, ComponentInput], debug: bool = False) -> Dict[str, ComponentOutput]:
def run(self, data: Dict[str, Any], debug: bool = False) -> Dict[str, Any]:
"""
Wrapper on top of Canals Pipeline.run(). Adds the `stores` parameter to all nodes.

View File

@ -79,7 +79,7 @@ dependencies = [
"jsonschema",
# Preview
"canals==0.2.2",
"canals>=0.3,<0.4",
# Agent events
"events",

View File

@ -1,9 +1,8 @@
from typing import Dict, Any
from dataclasses import dataclass
import pytest
from haystack.preview import Pipeline, component, NoSuchStoreError, ComponentInput, ComponentOutput
from haystack.preview import Pipeline, component, NoSuchStoreError
class MockStore:
@ -34,15 +33,21 @@ def test_pipeline_stores_in_params():
@component
class MockComponent:
@dataclass
class Input(ComponentInput):
class Input:
value: int
stores: Dict[str, Any]
@dataclass
class Output(ComponentOutput):
class Output:
value: int
@component.input
def input(self):
return MockComponent.Input
@component.output
def output(self):
return MockComponent.Output
def run(self, data: Input) -> Output:
assert data.stores == {"first_store": store_1, "second_store": store_2}
return MockComponent.Output(value=data.value)