mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-01-05 19:47:45 +00:00
parent
8750d92763
commit
7848f00d01
@ -1,3 +1,3 @@
|
||||
from canals.component import component, ComponentInput, ComponentOutput
|
||||
from canals.component import component
|
||||
from haystack.preview.dataclasses import Document
|
||||
from haystack.preview.pipeline import Pipeline, PipelineError, NoSuchStoreError, load_pipelines, save_pipelines
|
||||
|
||||
@ -2,12 +2,11 @@ from typing import List, Optional, Dict, Any, Union, BinaryIO, Literal, get_args
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
import whisper
|
||||
|
||||
from haystack.preview import component, ComponentInput, ComponentOutput, Document
|
||||
from haystack.preview import component, Document
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -24,15 +23,21 @@ class LocalWhisperTranscriber:
|
||||
[github repo](https://github.com/openai/whisper).
|
||||
"""
|
||||
|
||||
@dataclass
|
||||
class Input(ComponentInput):
|
||||
class Input:
|
||||
audio_files: List[Path]
|
||||
whisper_params: Optional[Dict[str, Any]] = None
|
||||
|
||||
@dataclass
|
||||
class Output(ComponentOutput):
|
||||
class Output:
|
||||
documents: List[Document]
|
||||
|
||||
@component.input
|
||||
def input(self): # type: ignore
|
||||
return LocalWhisperTranscriber.Input
|
||||
|
||||
@component.output
|
||||
def output(self): # type: ignore
|
||||
return LocalWhisperTranscriber.Output
|
||||
|
||||
def __init__(self, model_name_or_path: WhisperLocalModel = "large", device: Optional[str] = None):
|
||||
"""
|
||||
:param model_name_or_path: Name of the model to use. Set it to one of the following values:
|
||||
@ -76,7 +81,7 @@ class LocalWhisperTranscriber:
|
||||
if not data.whisper_params:
|
||||
data.whisper_params = {}
|
||||
documents = self.transcribe(data.audio_files, **data.whisper_params)
|
||||
return LocalWhisperTranscriber.Output(documents)
|
||||
return self.output(documents=documents)
|
||||
|
||||
def transcribe(self, audio_files: Sequence[Union[str, Path, BinaryIO]], **kwargs) -> List[Document]:
|
||||
"""
|
||||
|
||||
@ -4,10 +4,9 @@ import os
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass
|
||||
|
||||
from haystack.utils import request_with_retry
|
||||
from haystack.preview import component, ComponentInput, ComponentOutput, Document
|
||||
from haystack.preview import component, Document
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -29,15 +28,21 @@ class RemoteWhisperTranscriber:
|
||||
[Whisper API documentation](https://platform.openai.com/docs/guides/speech-to-text)
|
||||
"""
|
||||
|
||||
@dataclass
|
||||
class Input(ComponentInput):
|
||||
class Input:
|
||||
audio_files: List[Path]
|
||||
whisper_params: Optional[Dict[str, Any]] = None
|
||||
|
||||
@dataclass
|
||||
class Output(ComponentOutput):
|
||||
class Output:
|
||||
documents: List[Document]
|
||||
|
||||
@component.input
|
||||
def input(self): # type: ignore
|
||||
return RemoteWhisperTranscriber.Input
|
||||
|
||||
@component.output
|
||||
def output(self): # type: ignore
|
||||
return RemoteWhisperTranscriber.Output
|
||||
|
||||
def __init__(
|
||||
self, api_key: str, model_name: WhisperRemoteModel = "whisper-1", api_base: str = "https://api.openai.com/v1"
|
||||
):
|
||||
@ -77,7 +82,7 @@ class RemoteWhisperTranscriber:
|
||||
if not data.whisper_params:
|
||||
data.whisper_params = {}
|
||||
documents = self.transcribe(data.audio_files, **data.whisper_params)
|
||||
return RemoteWhisperTranscriber.Output(documents)
|
||||
return self.output(documents=documents)
|
||||
|
||||
def transcribe(self, audio_files: Sequence[Union[str, Path, BinaryIO]], **kwargs) -> List[Document]:
|
||||
"""
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Any, Optional
|
||||
|
||||
from haystack.preview import component, Document, ComponentInput, ComponentOutput
|
||||
from haystack.preview import component, Document
|
||||
from haystack.preview.document_stores import MemoryDocumentStore
|
||||
|
||||
|
||||
@ -11,8 +10,7 @@ class MemoryRetriever:
|
||||
A component for retrieving documents from a MemoryDocumentStore using the BM25 algorithm.
|
||||
"""
|
||||
|
||||
@dataclass
|
||||
class Input(ComponentInput):
|
||||
class Input:
|
||||
"""
|
||||
Input data for the MemoryRetriever component.
|
||||
|
||||
@ -29,8 +27,7 @@ class MemoryRetriever:
|
||||
scale_score: bool
|
||||
stores: Dict[str, Any]
|
||||
|
||||
@dataclass
|
||||
class Output(ComponentOutput):
|
||||
class Output:
|
||||
"""
|
||||
Output data from the MemoryRetriever component.
|
||||
|
||||
@ -39,6 +36,14 @@ class MemoryRetriever:
|
||||
|
||||
documents: List[List[Document]]
|
||||
|
||||
@component.input
|
||||
def input(self): # type: ignore
|
||||
return MemoryRetriever.Input
|
||||
|
||||
@component.output
|
||||
def output(self): # type: ignore
|
||||
return MemoryRetriever.Output
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
document_store_name: str,
|
||||
@ -86,4 +91,4 @@ class MemoryRetriever:
|
||||
query=query, filters=data.filters, top_k=data.top_k, scale_score=data.scale_score
|
||||
)
|
||||
)
|
||||
return MemoryRetriever.Output(documents=docs)
|
||||
return self.output(documents=docs)
|
||||
|
||||
@ -2,7 +2,6 @@ from typing import List, Dict, Any, Optional, Callable
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from canals.component import ComponentInput, ComponentOutput
|
||||
from canals.pipeline import (
|
||||
Pipeline as CanalsPipeline,
|
||||
PipelineError,
|
||||
@ -57,7 +56,7 @@ class Pipeline(CanalsPipeline):
|
||||
except KeyError as e:
|
||||
raise NoSuchStoreError(f"No store named '{name}' is connected to this pipeline.") from e
|
||||
|
||||
def run(self, data: Dict[str, ComponentInput], debug: bool = False) -> Dict[str, ComponentOutput]:
|
||||
def run(self, data: Dict[str, Any], debug: bool = False) -> Dict[str, Any]:
|
||||
"""
|
||||
Wrapper on top of Canals Pipeline.run(). Adds the `stores` parameter to all nodes.
|
||||
|
||||
|
||||
@ -79,7 +79,7 @@ dependencies = [
|
||||
"jsonschema",
|
||||
|
||||
# Preview
|
||||
"canals==0.2.2",
|
||||
"canals>=0.3,<0.4",
|
||||
|
||||
# Agent events
|
||||
"events",
|
||||
|
||||
@ -1,9 +1,8 @@
|
||||
from typing import Dict, Any
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
|
||||
from haystack.preview import Pipeline, component, NoSuchStoreError, ComponentInput, ComponentOutput
|
||||
from haystack.preview import Pipeline, component, NoSuchStoreError
|
||||
|
||||
|
||||
class MockStore:
|
||||
@ -34,15 +33,21 @@ def test_pipeline_stores_in_params():
|
||||
|
||||
@component
|
||||
class MockComponent:
|
||||
@dataclass
|
||||
class Input(ComponentInput):
|
||||
class Input:
|
||||
value: int
|
||||
stores: Dict[str, Any]
|
||||
|
||||
@dataclass
|
||||
class Output(ComponentOutput):
|
||||
class Output:
|
||||
value: int
|
||||
|
||||
@component.input
|
||||
def input(self):
|
||||
return MockComponent.Input
|
||||
|
||||
@component.output
|
||||
def output(self):
|
||||
return MockComponent.Output
|
||||
|
||||
def run(self, data: Input) -> Output:
|
||||
assert data.stores == {"first_store": store_1, "second_store": store_2}
|
||||
return MockComponent.Output(value=data.value)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user