refactor: Remove reimplementations of default from_dict/to_dict and corresponding tests in 2.0 (#6108)

* whisper transcriber

* remove from/to_dict from builders

* remove from/to_dict from embedders

* remove from/to_dict from fetcher, file_converters

* remove from/to_dict from generators, preprocessors

* remove from/to_dict from ranker, reader

* remove from/to_dict from router, sampler, websearch

* pylint

* reno

* refactor import

* remove unused import
Julian Risch 2023-10-19 11:17:02 +02:00 committed by GitHub
parent 6df077cbb4
commit 9f3b6512be
49 changed files with 49 additions and 926 deletions

@@ -3,7 +3,7 @@ from typing import List, Optional, Dict, Any, Union, BinaryIO, Literal, get_args
import logging
from pathlib import Path
from haystack.preview import component, Document, default_to_dict, default_from_dict, ComponentError
from haystack.preview import component, Document, default_to_dict, ComponentError
from haystack.preview.lazy_imports import LazyImport
with LazyImport("Run 'pip install openai-whisper'") as whisper_import:
@@ -66,13 +66,6 @@ class LocalWhisperTranscriber:
self, model_name_or_path=self.model_name, device=str(self.device), whisper_params=self.whisper_params
)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "LocalWhisperTranscriber":
"""
Deserialize this component from a dictionary.
"""
return default_from_dict(cls, data)
@component.output_types(documents=List[Document])
def run(self, audio_files: List[Path], whisper_params: Optional[Dict[str, Any]] = None):
"""

@@ -6,7 +6,7 @@ import logging
from pathlib import Path
from haystack.preview.utils import request_with_retry
from haystack.preview import component, Document, default_to_dict, default_from_dict
from haystack.preview import component, Document, default_to_dict
logger = logging.getLogger(__name__)
@@ -54,25 +54,6 @@ class RemoteWhisperTranscriber:
self.api_base = api_base
self.whisper_params = whisper_params or {}
def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary.
"""
return default_to_dict(
self,
model_name=self.model_name,
api_key=self.api_key,
api_base=self.api_base,
whisper_params=self.whisper_params,
)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "RemoteWhisperTranscriber":
"""
Deserialize this component from a dictionary.
"""
return default_from_dict(cls, data)
@component.output_types(documents=List[Document])
def run(self, audio_files: List[Path], whisper_params: Optional[Dict[str, Any]] = None):
"""
@@ -145,3 +126,12 @@ class RemoteWhisperTranscriber:
transcriptions.append(transcription)
return transcriptions
def to_dict(self) -> Dict[str, Any]:
"""
This method overrides the default serializer in order to avoid leaking the `api_key` value passed
to the constructor.
"""
return default_to_dict(
self, model_name=self.model_name, api_base=self.api_base, whisper_params=self.whisper_params
)
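
A minimal usage sketch (not part of the diff) of the behavior the override above guarantees. The key value is made up, the module path is assumed by analogy with whisper_local, and model_name defaulting to "whisper-1" is taken from the tests further down:

from haystack.preview.components.audio.whisper_remote import RemoteWhisperTranscriber  # path assumed

transcriber = RemoteWhisperTranscriber(api_key="sk-not-a-real-key")
data = transcriber.to_dict()
assert data["type"] == "RemoteWhisperTranscriber"      # shape produced by default_to_dict
assert "api_key" not in data["init_parameters"]        # the secret never reaches the dict
assert data["init_parameters"]["model_name"] == "whisper-1"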

@@ -2,7 +2,7 @@ import logging
import re
from typing import List, Dict, Any, Optional
from haystack.preview import component, GeneratedAnswer, Document, default_to_dict, default_from_dict
from haystack.preview import component, GeneratedAnswer, Document
logger = logging.getLogger(__name__)
@@ -107,19 +107,6 @@ class AnswerBuilder:
return {"answers": all_answers}
def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary.
"""
return default_to_dict(self, pattern=self.pattern, reference_pattern=self.reference_pattern)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "AnswerBuilder":
"""
Deserialize this component from a dictionary.
"""
return default_from_dict(cls, data)
@staticmethod
def _extract_answer_string(reply: str, pattern: Optional[str] = None) -> str:
"""

@@ -3,7 +3,7 @@ from typing import Dict, Any
from jinja2 import Template, meta
from haystack.preview import component
from haystack.preview import default_to_dict, default_from_dict
from haystack.preview import default_to_dict
@component
@@ -36,10 +36,6 @@ class PromptBuilder:
def to_dict(self) -> Dict[str, Any]:
return default_to_dict(self, template=self._template_string)
@classmethod
def from_dict(cls, data) -> "PromptBuilder":
return default_from_dict(cls, data)
@component.output_types(prompt=str)
def run(self, **kwargs):
return {"prompt": self.template.render(kwargs)}

@@ -5,7 +5,7 @@ import openai
from tqdm import tqdm
from haystack.preview import component, Document, default_to_dict, default_from_dict
from haystack.preview import component, Document, default_to_dict
@component
@@ -89,13 +89,6 @@ class OpenAIDocumentEmbedder:
embedding_separator=self.embedding_separator,
)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "OpenAIDocumentEmbedder":
"""
Deserialize this component from a dictionary.
"""
return default_from_dict(cls, data)
def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
"""
Prepare the texts to embed by concatenating the Document text with the metadata fields to embed.

@@ -3,7 +3,7 @@ import os
import openai
from haystack.preview import component, default_to_dict, default_from_dict
from haystack.preview import component, default_to_dict
@component
@@ -66,13 +66,6 @@ class OpenAITextEmbedder:
self, model_name=self.model_name, organization=self.organization, prefix=self.prefix, suffix=self.suffix
)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "OpenAITextEmbedder":
"""
Deserialize this component from a dictionary.
"""
return default_from_dict(cls, data)
@component.output_types(embedding=List[float], metadata=Dict[str, Any])
def run(self, text: str):
"""Embed a string."""

@@ -1,6 +1,6 @@
from typing import List, Optional, Union, Dict, Any
from haystack.preview import component, Document, default_to_dict, default_from_dict
from haystack.preview import component, Document, default_to_dict
from haystack.preview.components.embedders.backends.sentence_transformers_backend import (
_SentenceTransformersEmbeddingBackendFactory,
)
@@ -81,13 +81,6 @@ class SentenceTransformersDocumentEmbedder:
embedding_separator=self.embedding_separator,
)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "SentenceTransformersDocumentEmbedder":
"""
Deserialize this component from a dictionary.
"""
return default_from_dict(cls, data)
def warm_up(self):
"""
Load the embedding backend.

@@ -1,6 +1,6 @@
from typing import List, Optional, Union, Dict, Any
from haystack.preview import component, default_to_dict, default_from_dict
from haystack.preview import component, default_to_dict
from haystack.preview.components.embedders.backends.sentence_transformers_backend import (
_SentenceTransformersEmbeddingBackendFactory,
)
@@ -72,13 +72,6 @@ class SentenceTransformersTextEmbedder:
normalize_embeddings=self.normalize_embeddings,
)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "SentenceTransformersTextEmbedder":
"""
Deserialize this component from a dictionary.
"""
return default_from_dict(cls, data)
def warm_up(self):
"""
Load the embedding backend.

@@ -1,14 +1,14 @@
import logging
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Callable, Dict, List, Optional, Tuple
import requests
from requests import Response
from requests.exceptions import HTTPError
from tenacity import RetryCallState, retry, retry_if_exception_type, stop_after_attempt, wait_exponential
from haystack.preview import component, default_from_dict, default_to_dict
from haystack.preview import component
from haystack.preview.dataclasses import ByteStream
from haystack.preview.version import __version__
@@ -95,25 +95,6 @@ class LinkContentFetcher:
self._get_response: Callable = get_response
def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary.
"""
return default_to_dict(
self,
raise_on_failure=self.raise_on_failure,
user_agents=self.user_agents,
retry_attempts=self.retry_attempts,
timeout=self.timeout,
)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "LinkContentFetcher":
"""
Deserialize this component from a dictionary.
"""
return default_from_dict(cls, data)
@component.output_types(streams=List[ByteStream])
def run(self, urls: List[str]):
"""

@@ -2,7 +2,7 @@ from pathlib import Path
from typing import List, Union, Optional, Dict, Any
from haystack.preview.lazy_imports import LazyImport
from haystack.preview import component, Document, default_to_dict, default_from_dict
from haystack.preview import component, Document, default_to_dict
with LazyImport(message="Run 'pip install azure-ai-formrecognizer>=3.2.0b2'") as azure_import:
@@ -84,13 +84,6 @@ class AzureOCRDocumentConverter:
self, endpoint=self.endpoint, api_key=self.api_key, model_id=self.model_id, id_hash_keys=self.id_hash_keys
)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "AzureOCRDocumentConverter":
"""
Deserialize this component from a dictionary.
"""
return default_from_dict(cls, data)
@staticmethod
def _convert_azure_result_to_document(
result: "AnalyzeResult", id_hash_keys: List[str], file_suffix: str

@@ -1,8 +1,8 @@
import logging
from typing import List, Optional, Dict, Any, Union
from typing import List, Optional, Union
from pathlib import Path
from haystack.preview import Document, component, default_to_dict, default_from_dict
from haystack.preview import Document, component
from haystack.preview.dataclasses import ByteStream
from haystack.preview.lazy_imports import LazyImport
@@ -27,15 +27,6 @@ class HTMLToDocument:
boilerpy3_import.check()
self.id_hash_keys = id_hash_keys or []
def to_dict(self) -> Dict[str, Any]:
"""Serialize the component to a dictionary."""
return default_to_dict(self, id_hash_keys=self.id_hash_keys)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "HTMLToDocument":
"""Deserialize the component from a dictionary."""
return default_from_dict(cls, data)
@component.output_types(documents=List[Document])
def run(self, sources: List[Union[str, Path, ByteStream]]):
"""

@@ -1,11 +1,11 @@
import io
import logging
from typing import List, Optional, Dict, Any, Union
from typing import List, Optional, Union
from pathlib import Path
from haystack.preview.dataclasses import ByteStream
from haystack.preview.lazy_imports import LazyImport
from haystack.preview import Document, component, default_to_dict, default_from_dict
from haystack.preview import Document, component
with LazyImport("Run 'pip install pypdf'") as pypdf_import:
from pypdf import PdfReader
@@ -30,22 +30,6 @@ class PyPDFToDocument:
pypdf_import.check()
self.id_hash_keys = id_hash_keys or []
def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary.
:return: The dictionary containing the component's data.
"""
return default_to_dict(self, id_hash_keys=self.id_hash_keys)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "PyPDFToDocument":
"""
Deserialize this component from a dictionary.
:param data: The dictionary containing the component's data.
:return: The component instance.
"""
return default_from_dict(cls, data)
@component.output_types(documents=List[Document])
def run(self, sources: List[Union[str, Path, ByteStream]], id_hash_keys: Optional[List[str]] = None):
"""

@@ -1,9 +1,9 @@
import logging
from pathlib import Path
from typing import Optional, List, Union, Dict, Any
from typing import Optional, List, Union
from haystack.preview.lazy_imports import LazyImport
from haystack.preview import component, Document, default_to_dict, default_from_dict
from haystack.preview import component, Document
with LazyImport("Run 'pip install tika'") as tika_import:
@@ -70,16 +70,3 @@ class TikaDocumentConverter:
logger.error("Could not convert file at '%s' to Document. Error: %s", str(path), e)
return {"documents": documents}
def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary.
"""
return default_to_dict(self, tika_url=self.tika_url, id_hash_keys=self.id_hash_keys)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "TikaDocumentConverter":
"""
Deserialize this component from a dictionary.
"""
return default_from_dict(cls, data)

@@ -1,12 +1,12 @@
import logging
from pathlib import Path
from typing import Optional, List, Union, Dict, Any
from typing import Optional, List, Union, Dict
from canals.errors import PipelineRuntimeError
from tqdm import tqdm
from haystack.preview.lazy_imports import LazyImport
from haystack.preview import Document, component, default_to_dict, default_from_dict
from haystack.preview import Document, component
with LazyImport("Run 'pip install langdetect'") as langdetect_import:
import langdetect
@@ -61,27 +61,6 @@ class TextFileToDocument:
self.id_hash_keys = id_hash_keys or []
self.progress_bar = progress_bar
def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary.
"""
return default_to_dict(
self,
encoding=self.encoding,
remove_numeric_tables=self.remove_numeric_tables,
numeric_row_threshold=self.numeric_row_threshold,
valid_languages=self.valid_languages,
id_hash_keys=self.id_hash_keys,
progress_bar=self.progress_bar,
)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "TextFileToDocument":
"""
Deserialize this component from a dictionary.
"""
return default_from_dict(cls, data)
@component.output_types(documents=List[Document])
def run(
self,

@@ -2,7 +2,7 @@ import logging
from typing import Any, Dict, List, Literal, Optional, Union
from copy import deepcopy
from haystack.preview import component, default_from_dict, default_to_dict
from haystack.preview import component, default_to_dict
from haystack.preview.lazy_imports import LazyImport
with LazyImport(message="Run 'pip install transformers'") as transformers_import:
@@ -170,13 +170,6 @@ class HuggingFaceLocalGenerator:
stop_words=self.stop_words,
)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceLocalGenerator":
"""
Deserialize this component from a dictionary.
"""
return default_from_dict(cls, data)
@component.output_types(replies=List[str])
def run(self, prompt: str):
if self.pipeline is None:

@@ -3,9 +3,9 @@ import re
from copy import deepcopy
from functools import partial, reduce
from itertools import chain
from typing import Any, Dict, Generator, List, Optional, Set
from typing import Generator, List, Optional, Set
from haystack.preview import Document, component, default_from_dict, default_to_dict
from haystack.preview import Document, component
logger = logging.getLogger(__name__)
@@ -91,26 +91,6 @@ class DocumentCleaner:
return {"documents": cleaned_docs}
def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary.
"""
return default_to_dict(
self,
remove_empty_lines=self.remove_empty_lines,
remove_extra_whitespaces=self.remove_extra_whitespaces,
remove_repeated_substrings=self.remove_repeated_substrings,
remove_substrings=self.remove_substrings,
remove_regex=self.remove_regex,
)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "DocumentCleaner":
"""
Deserialize this component from a dictionary.
"""
return default_from_dict(cls, data)
def _remove_empty_lines(self, text: str) -> str:
"""
Remove empty lines and lines that contain nothing but whitespaces from text.

@@ -1,9 +1,9 @@
from copy import deepcopy
from typing import List, Dict, Any, Literal
from typing import List, Literal
from more_itertools import windowed
from haystack.preview import component, Document, default_from_dict, default_to_dict
from haystack.preview import component, Document
@component
@@ -61,21 +61,6 @@ class TextDocumentSplitter:
split_docs += [Document(text=txt, metadata=metadata, id_hash_keys=id_hash_keys) for txt in text_splits]
return {"documents": split_docs}
def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary.
"""
return default_to_dict(
self, split_by=self.split_by, split_length=self.split_length, split_overlap=self.split_overlap
)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "TextDocumentSplitter":
"""
Deserialize this component from a dictionary.
"""
return default_from_dict(cls, data)
def _split_into_units(self, text: str, split_by: Literal["word", "sentence", "passage"]) -> List[str]:
if split_by == "passage":
split_at = "\n\n"

@@ -1,7 +1,7 @@
import logging
from typing import List, Dict, Any, Optional
from typing import List, Dict, Optional
from haystack.preview import component, default_from_dict, default_to_dict
from haystack.preview import component
from haystack.preview.lazy_imports import LazyImport
logger = logging.getLogger(__name__)
@@ -63,19 +63,6 @@ class TextLanguageClassifier:
return output
def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary.
"""
return default_to_dict(self, languages=self.languages)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "TextLanguageClassifier":
"""
Deserialize this component from a dictionary.
"""
return default_from_dict(cls, data)
def detect_language(self, text: str) -> Optional[str]:
try:
language = langdetect.detect(text)

@@ -2,7 +2,7 @@ import logging
from pathlib import Path
from typing import List, Union, Dict, Any, Optional
from haystack.preview import ComponentError, Document, component, default_from_dict, default_to_dict
from haystack.preview import ComponentError, Document, component, default_to_dict
from haystack.preview.lazy_imports import LazyImport
logger = logging.getLogger(__name__)
@@ -89,13 +89,6 @@ class SimilarityRanker:
top_k=self.top_k,
)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "SimilarityRanker":
"""
Deserialize this component from a dictionary.
"""
return default_from_dict(cls, data)
@component.output_types(documents=List[Document])
def run(self, query: str, documents: List[Document], top_k: Optional[int] = None):
"""

@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import math
import warnings
from haystack.preview import component, default_from_dict, default_to_dict, ComponentError, Document, ExtractedAnswer
from haystack.preview import component, default_to_dict, ComponentError, Document, ExtractedAnswer
from haystack.preview.lazy_imports import LazyImport
with LazyImport(
@@ -97,13 +97,6 @@ class ExtractiveReader:
calibration_factor=self.calibration_factor,
)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "ExtractiveReader":
"""
Deserialize this component from a dictionary.
"""
return default_from_dict(cls, data)
def warm_up(self):
if self.model is None:
if torch.cuda.is_available():

@@ -2,9 +2,9 @@ import logging
import mimetypes
from collections import defaultdict
from pathlib import Path
from typing import List, Union, Optional, Dict, Any
from typing import List, Union, Optional, Dict
from haystack.preview import component, default_from_dict, default_to_dict
from haystack.preview import component
from haystack.preview.dataclasses import ByteStream
logger = logging.getLogger(__name__)
@@ -42,19 +42,6 @@ class FileTypeRouter:
component.set_output_types(self, unclassified=List[Path], **{mime_type: List[Path] for mime_type in mime_types})
self.mime_types = mime_types
def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary.
"""
return default_to_dict(self, mime_types=self.mime_types)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "FileTypeRouter":
"""
Deserialize this component from a dictionary.
"""
return default_from_dict(cls, data)
def run(self, sources: List[Union[str, Path, ByteStream]]) -> Dict[str, List[Union[ByteStream, Path]]]:
"""
Categorizes the provided data sources by their MIME types.

@@ -1,6 +1,6 @@
from typing import Any, Dict, List
from typing import Dict, List
from haystack.preview import component, default_from_dict, default_to_dict, Document
from haystack.preview import component, Document
from haystack.preview.utils.filters import document_matches_filter
@@ -52,16 +52,3 @@ class MetadataRouter:
output["unmatched"] = unmatched_documents
return output
def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary.
"""
return default_to_dict(self, rules=self.rules)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "MetadataRouter":
"""
Deserialize this component from a dictionary.
"""
return default_from_dict(cls, data)

@@ -1,7 +1,7 @@
import logging
from typing import List, Optional, Dict, Any
from typing import List, Optional
from haystack.preview import ComponentError, Document, component, default_from_dict, default_to_dict
from haystack.preview import ComponentError, Document, component
from haystack.preview.lazy_imports import LazyImport
logger = logging.getLogger(__name__)
@@ -48,19 +48,6 @@ class TopPSampler:
self.top_p = top_p
self.score_field = score_field
def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary.
"""
return default_to_dict(self, top_p=self.top_p, score_field=self.score_field)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "TopPSampler":
"""
Deserialize this component from a dictionary.
"""
return default_from_dict(cls, data)
@component.output_types(documents=List[Document])
def run(self, documents: List[Document], top_p: Optional[float] = None):
"""

@@ -4,7 +4,7 @@ from typing import Dict, List, Optional, Any
import requests
from haystack.preview import Document, component, default_from_dict, default_to_dict, ComponentError
from haystack.preview import Document, component, default_to_dict, ComponentError
logger = logging.getLogger(__name__)
@@ -58,13 +58,6 @@ class SerperDevWebSearch:
search_params=self.search_params,
)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "SerperDevWebSearch":
"""
Deserialize this component from a dictionary.
"""
return default_from_dict(cls, data)
@component.output_types(documents=List[Document], links=List[str])
def run(self, query: str):
"""

@@ -0,0 +1,4 @@
---
preview:
- |
Removed implementations of from_dict and to_dict from all components where they had the same effect as the default implementation from Canals: https://github.com/deepset-ai/canals/blob/main/canals/serialization.py#L12-L13. This refactoring does not change the behavior of the components.
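
For reference, the two Canals helpers linked above behave roughly as sketched below (an approximation, not the verbatim Canals source; error handling is simplified). It shows why a hand-written from_dict that merely forwarded to default_from_dict(cls, data) added nothing, which is what this commit deletes; only overrides that change the stored parameters, such as RemoteWhisperTranscriber.to_dict dropping api_key, survive.

from typing import Any, Dict

def default_to_dict(obj: Any, **init_parameters: Any) -> Dict[str, Any]:
    # Serialize as the class name plus whatever init parameters the component passes in.
    return {"type": obj.__class__.__name__, "init_parameters": init_parameters}

def default_from_dict(cls: type, data: Dict[str, Any]) -> Any:
    # Rebuild the component by calling its constructor with the stored parameters.
    if data.get("type") != cls.__name__:
        raise ValueError(f"Can't deserialize {data.get('type')!r} as {cls.__name__!r}")
    return cls(**data.get("init_parameters", {}))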

@@ -53,21 +53,6 @@ class TestLocalWhisperTranscriber:
},
}
@pytest.mark.unit
def test_from_dict(self):
data = {
"type": "LocalWhisperTranscriber",
"init_parameters": {
"model_name_or_path": "tiny",
"device": "cuda",
"whisper_params": {"return_segments": True, "temperature": [0.1, 0.6, 0.8]},
},
}
transcriber = LocalWhisperTranscriber.from_dict(data)
assert transcriber.model_name == "tiny"
assert transcriber.device == torch.device("cuda")
assert transcriber.whisper_params == {"return_segments": True, "temperature": [0.1, 0.6, 0.8]}
@pytest.mark.unit
def test_warmup(self):
with patch("haystack.preview.components.audio.whisper_local.whisper") as mocked_whisper:

@@ -33,7 +33,6 @@ class TestRemoteWhisperTranscriber:
"type": "RemoteWhisperTranscriber",
"init_parameters": {
"model_name": "whisper-1",
"api_key": "test",
"api_base": "https://api.openai.com/v1",
"whisper_params": {},
},
@@ -52,29 +51,11 @@ class TestRemoteWhisperTranscriber:
"type": "RemoteWhisperTranscriber",
"init_parameters": {
"model_name": "whisper-1",
"api_key": "test",
"api_base": "https://my.api.base/something_else/v3",
"whisper_params": {"return_segments": True, "temperature": [0.1, 0.6, 0.8]},
},
}
@pytest.mark.unit
def test_from_dict(self):
data = {
"type": "RemoteWhisperTranscriber",
"init_parameters": {
"model_name": "whisper-1",
"api_key": "test",
"api_base": "https://my.api.base/something_else/v3",
"whisper_params": {"return_segments": True, "temperature": [0.1, 0.6, 0.8]},
},
}
transcriber = RemoteWhisperTranscriber.from_dict(data)
assert transcriber.model_name == "whisper-1"
assert transcriber.api_key == "test"
assert transcriber.api_base == "https://my.api.base/something_else/v3"
assert transcriber.whisper_params == {"return_segments": True, "temperature": [0.1, 0.6, 0.8]}
@pytest.mark.unit
def test_run_with_path(self, preview_samples_path):
mock_response = MagicMock()

@@ -7,31 +7,6 @@ from haystack.preview.components.builders.answer_builder import AnswerBuilder
class TestAnswerBuilder:
@pytest.mark.unit
def test_to_dict(self):
component = AnswerBuilder()
data = component.to_dict()
assert data == {"type": "AnswerBuilder", "init_parameters": {"pattern": None, "reference_pattern": None}}
@pytest.mark.unit
def test_to_dict_with_custom_init_parameters(self):
component = AnswerBuilder(pattern="pattern", reference_pattern="reference_pattern")
data = component.to_dict()
assert data == {
"type": "AnswerBuilder",
"init_parameters": {"pattern": "pattern", "reference_pattern": "reference_pattern"},
}
@pytest.mark.unit
def test_from_dict(self):
data = {
"type": "AnswerBuilder",
"init_parameters": {"pattern": "pattern", "reference_pattern": "reference_pattern"},
}
component = AnswerBuilder.from_dict(data)
assert component.pattern == "pattern"
assert component.reference_pattern == "reference_pattern"
@pytest.mark.unit
def test_run_unmatching_input_len(self):
component = AnswerBuilder()

@@ -16,13 +16,6 @@ def test_to_dict():
assert res == {"type": "PromptBuilder", "init_parameters": {"template": "This is a {{ variable }}"}}
@pytest.mark.unit
def test_from_dict():
data = {"type": "PromptBuilder", "init_parameters": {"template": "This is a {{ variable }}"}}
builder = PromptBuilder.from_dict(data)
builder._template_string == "This is a {{ variable }}"
@pytest.mark.unit
def test_run():
builder = PromptBuilder(template="This is a {{ variable }}")

@@ -118,53 +118,6 @@ class TestOpenAIDocumentEmbedder:
},
}
@pytest.mark.unit
def test_from_dict(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key")
data = {
"type": "OpenAIDocumentEmbedder",
"init_parameters": {
"model_name": "model",
"organization": "my-org",
"prefix": "prefix",
"suffix": "suffix",
"batch_size": 64,
"progress_bar": False,
"metadata_fields_to_embed": ["test_field"],
"embedding_separator": " | ",
},
}
component = OpenAIDocumentEmbedder.from_dict(data)
assert openai.api_key == "fake-api-key"
assert component.model_name == "model"
assert component.organization == "my-org"
assert openai.organization == "my-org"
assert component.prefix == "prefix"
assert component.suffix == "suffix"
assert component.batch_size == 64
assert component.progress_bar is False
assert component.metadata_fields_to_embed == ["test_field"]
assert component.embedding_separator == " | "
@pytest.mark.unit
def test_from_dict_fail_wo_env_var(self, monkeypatch):
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
data = {
"type": "OpenAIDocumentEmbedder",
"init_parameters": {
"model_name": "model",
"organization": "my-org",
"prefix": "prefix",
"suffix": "suffix",
"batch_size": 64,
"progress_bar": False,
"metadata_fields_to_embed": ["test_field"],
"embedding_separator": " | ",
},
}
with pytest.raises(ValueError, match="OpenAIDocumentEmbedder expects an OpenAI API key"):
OpenAIDocumentEmbedder.from_dict(data)
@pytest.mark.unit
def test_prepare_texts_to_embed_w_metadata(self):
documents = [

@@ -86,41 +86,6 @@ class TestOpenAITextEmbedder:
},
}
@pytest.mark.unit
def test_from_dict(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key")
data = {
"type": "OpenAITextEmbedder",
"init_parameters": {
"model_name": "model",
"organization": "fake-organization",
"prefix": "prefix",
"suffix": "suffix",
},
}
component = OpenAITextEmbedder.from_dict(data)
assert openai.api_key == "fake-api-key"
assert component.model_name == "model"
assert component.organization == "fake-organization"
assert openai.organization == "fake-organization"
assert component.prefix == "prefix"
assert component.suffix == "suffix"
@pytest.mark.unit
def test_from_dict_fail_wo_env_var(self, monkeypatch):
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
data = {
"type": "OpenAITextEmbedder",
"init_parameters": {
"model_name": "model",
"organization": "fake-organization",
"prefix": "prefix",
"suffix": "suffix",
},
}
with pytest.raises(ValueError, match="OpenAITextEmbedder expects an OpenAI API key"):
OpenAITextEmbedder.from_dict(data)
@pytest.mark.unit
def test_run(self):
model = "text-similarity-ada-001"
@@ -138,7 +103,7 @@ class TestOpenAITextEmbedder:
)
assert len(result["embedding"]) == 1536
assert all([isinstance(x, float) for x in result["embedding"]])
assert all(isinstance(x, float) for x in result["embedding"])
assert result["metadata"] == {"model": model, "usage": {"prompt_tokens": 4, "total_tokens": 4}}
@pytest.mark.unit

@@ -100,35 +100,6 @@ class TestSentenceTransformersDocumentEmbedder:
},
}
@pytest.mark.unit
def test_from_dict(self):
data = {
"type": "SentenceTransformersDocumentEmbedder",
"init_parameters": {
"model_name_or_path": "model",
"device": "cuda",
"token": None,
"prefix": "prefix",
"suffix": "suffix",
"batch_size": 64,
"progress_bar": False,
"normalize_embeddings": False,
"embedding_separator": " - ",
"metadata_fields_to_embed": ["meta_field"],
},
}
component = SentenceTransformersDocumentEmbedder.from_dict(data)
assert component.model_name_or_path == "model"
assert component.device == "cuda"
assert component.token is None
assert component.prefix == "prefix"
assert component.suffix == "suffix"
assert component.batch_size == 64
assert component.progress_bar is False
assert component.normalize_embeddings is False
assert component.metadata_fields_to_embed == ["meta_field"]
assert component.embedding_separator == " - "
@pytest.mark.unit
@patch(
"haystack.preview.components.embedders.sentence_transformers_document_embedder._SentenceTransformersEmbeddingBackendFactory"

@@ -103,31 +103,6 @@ class TestSentenceTransformersTextEmbedder:
},
}
@pytest.mark.unit
def test_from_dict(self):
data = {
"type": "SentenceTransformersTextEmbedder",
"init_parameters": {
"model_name_or_path": "model",
"device": "cuda",
"token": True,
"prefix": "prefix",
"suffix": "suffix",
"batch_size": 64,
"progress_bar": False,
"normalize_embeddings": True,
},
}
component = SentenceTransformersTextEmbedder.from_dict(data)
assert component.model_name_or_path == "model"
assert component.device == "cuda"
assert component.token is True
assert component.prefix == "prefix"
assert component.suffix == "suffix"
assert component.batch_size == 64
assert component.progress_bar is False
assert component.normalize_embeddings is True
@pytest.mark.unit
@patch(
"haystack.preview.components.embedders.sentence_transformers_text_embedder._SentenceTransformersEmbeddingBackendFactory"

@@ -59,44 +59,6 @@ class TestLinkContentFetcher:
assert fetcher.retry_attempts == 1
assert fetcher.timeout == 2
@pytest.mark.unit
def test_to_dict(self):
fetcher = LinkContentFetcher()
assert fetcher.to_dict() == {
"type": "LinkContentFetcher",
"init_parameters": {
"raise_on_failure": True,
"user_agents": [DEFAULT_USER_AGENT],
"retry_attempts": 2,
"timeout": 3,
},
}
@pytest.mark.unit
def test_to_dict_with_params(self):
fetcher = LinkContentFetcher(raise_on_failure=False, user_agents=["test"], retry_attempts=1, timeout=2)
assert fetcher.to_dict() == {
"type": "LinkContentFetcher",
"init_parameters": {"raise_on_failure": False, "user_agents": ["test"], "retry_attempts": 1, "timeout": 2},
}
@pytest.mark.unit
def test_from_dict(self):
fetcher = LinkContentFetcher.from_dict(
{
"type": "LinkContentFetcher",
"init_parameters": {
"raise_on_failure": False,
"user_agents": ["test"],
"retry_attempts": 1,
"timeout": 2,
},
}
)
assert fetcher.raise_on_failure is False
assert fetcher.user_agents == ["test"]
assert fetcher.retry_attempts == 1
@pytest.mark.unit
def test_run_text(self):
correct_response = b"Example test response"

@@ -21,23 +21,6 @@ class TestAzureOCRDocumentConverter:
},
}
@pytest.mark.unit
def test_from_dict(self):
data = {
"type": "AzureOCRDocumentConverter",
"init_parameters": {
"api_key": "test_credential_key",
"endpoint": "test_endpoint",
"id_hash_keys": [],
"model_id": "prebuilt-read",
},
}
component = AzureOCRDocumentConverter.from_dict(data)
assert component.endpoint == "test_endpoint"
assert component.api_key == "test_credential_key"
assert component.id_hash_keys == []
assert component.model_id == "prebuilt-read"
@pytest.mark.unit
def test_run(self, preview_samples_path):
with patch("haystack.preview.components.file_converters.azure.DocumentAnalysisClient") as mock_azure_client:

@@ -7,24 +7,6 @@ from haystack.preview.dataclasses import ByteStream
class TestHTMLToDocument:
@pytest.mark.unit
def test_to_dict(self):
component = HTMLToDocument()
data = component.to_dict()
assert data == {"type": "HTMLToDocument", "init_parameters": {"id_hash_keys": []}}
@pytest.mark.unit
def test_to_dict_with_custom_init_parameters(self):
component = HTMLToDocument(id_hash_keys=["name"])
data = component.to_dict()
assert data == {"type": "HTMLToDocument", "init_parameters": {"id_hash_keys": ["name"]}}
@pytest.mark.unit
def test_from_dict(self):
data = {"type": "HTMLToDocument", "init_parameters": {"id_hash_keys": ["name"]}}
component = HTMLToDocument.from_dict(data)
assert component.id_hash_keys == ["name"]
@pytest.mark.unit
def test_run(self, preview_samples_path):
"""

@@ -7,24 +7,6 @@ from haystack.preview.dataclasses import ByteStream
class TestPyPDFToDocument:
@pytest.mark.unit
def test_to_dict(self):
component = PyPDFToDocument()
data = component.to_dict()
assert data == {"type": "PyPDFToDocument", "init_parameters": {"id_hash_keys": []}}
@pytest.mark.unit
def test_to_dict_with_custom_init_parameters(self):
component = PyPDFToDocument(id_hash_keys=["name"])
data = component.to_dict()
assert data == {"type": "PyPDFToDocument", "init_parameters": {"id_hash_keys": ["name"]}}
@pytest.mark.unit
def test_from_dict(self):
data = {"type": "PyPDFToDocument", "init_parameters": {"id_hash_keys": ["name"]}}
component = PyPDFToDocument.from_dict(data)
assert component.id_hash_keys == ["name"]
@pytest.mark.unit
def test_run(self, preview_samples_path):
"""

@@ -11,66 +11,6 @@ from haystack.preview.components.file_converters.txt import TextFileToDocument
class TestTextfileToDocument:
@pytest.mark.unit
def test_to_dict(self):
component = TextFileToDocument()
data = component.to_dict()
assert data == {
"type": "TextFileToDocument",
"init_parameters": {
"encoding": "utf-8",
"remove_numeric_tables": False,
"numeric_row_threshold": 0.4,
"valid_languages": [],
"id_hash_keys": [],
"progress_bar": True,
},
}
@pytest.mark.unit
def test_to_dict_with_custom_init_parameters(self):
component = TextFileToDocument(
encoding="latin-1",
remove_numeric_tables=True,
numeric_row_threshold=0.7,
valid_languages=["en", "de"],
id_hash_keys=["name"],
progress_bar=False,
)
data = component.to_dict()
assert data == {
"type": "TextFileToDocument",
"init_parameters": {
"encoding": "latin-1",
"remove_numeric_tables": True,
"numeric_row_threshold": 0.7,
"valid_languages": ["en", "de"],
"id_hash_keys": ["name"],
"progress_bar": False,
},
}
@pytest.mark.unit
def test_from_dict(self):
data = {
"type": "TextFileToDocument",
"init_parameters": {
"encoding": "latin-1",
"remove_numeric_tables": True,
"numeric_row_threshold": 0.7,
"valid_languages": ["en", "de"],
"id_hash_keys": ["name"],
"progress_bar": False,
},
}
component = TextFileToDocument.from_dict(data)
assert component.encoding == "latin-1"
assert component.remove_numeric_tables
assert component.numeric_row_threshold == 0.7
assert component.valid_languages == ["en", "de"]
assert component.id_hash_keys == ["name"]
assert not component.progress_bar
@pytest.mark.unit
def test_run(self, preview_samples_path):
"""

@@ -6,34 +6,6 @@ from haystack.preview.components.file_converters.tika import TikaDocumentConverter
class TestTikaDocumentConverter:
@pytest.mark.unit
def test_to_dict(self):
component = TikaDocumentConverter()
data = component.to_dict()
assert data == {
"type": "TikaDocumentConverter",
"init_parameters": {"tika_url": "http://localhost:9998/tika", "id_hash_keys": []},
}
@pytest.mark.unit
def test_to_dict_with_custom_init_parameters(self):
component = TikaDocumentConverter(tika_url="http://localhost:1234/tika", id_hash_keys=["text", "category"])
data = component.to_dict()
assert data == {
"type": "TikaDocumentConverter",
"init_parameters": {"tika_url": "http://localhost:1234/tika", "id_hash_keys": ["text", "category"]},
}
@pytest.mark.unit
def test_from_dict(self):
data = {
"type": "TikaDocumentConverter",
"init_parameters": {"tika_url": "http://localhost:9998/tika", "id_hash_keys": ["text", "category"]},
}
component = TikaDocumentConverter.from_dict(data)
assert component.tika_url == "http://localhost:9998/tika"
assert component.id_hash_keys == ["text", "category"]
@pytest.mark.unit
def test_run(self):
component = TikaDocumentConverter()

@@ -181,31 +181,6 @@ class TestHuggingFaceLocalGenerator:
},
}
@pytest.mark.unit
def test_from_dict(self):
data = {
"type": "HuggingFaceLocalGenerator",
"init_parameters": {
"pipeline_kwargs": {
"model": "gpt2",
"task": "text-generation",
"token": "test-token",
"device": "cuda:0",
},
"generation_kwargs": {"max_new_tokens": 100, "return_full_text": False},
},
}
component = HuggingFaceLocalGenerator.from_dict(data)
assert component.pipeline_kwargs == {
"model": "gpt2",
"task": "text-generation",
"token": "test-token",
"device": "cuda:0",
}
assert component.generation_kwargs == {"max_new_tokens": 100, "return_full_text": False}
@pytest.mark.unit
@patch("haystack.preview.components.generators.hugging_face.hugging_face_local.pipeline")
def test_warm_up(self, pipeline_mock):

@@ -16,61 +16,6 @@ class TestDocumentCleaner:
assert cleaner.remove_substrings is None
assert cleaner.remove_regex is None
@pytest.mark.unit
def test_to_dict(self):
cleaner = DocumentCleaner()
data = cleaner.to_dict()
assert data == {
"type": "DocumentCleaner",
"init_parameters": {
"remove_empty_lines": True,
"remove_extra_whitespaces": True,
"remove_repeated_substrings": False,
"remove_substrings": None,
"remove_regex": None,
},
}
@pytest.mark.unit
def test_to_dict_with_custom_init_parameters(self):
cleaner = DocumentCleaner(
remove_empty_lines=False,
remove_extra_whitespaces=False,
remove_repeated_substrings=True,
remove_substrings=["a", "b"],
remove_regex=r"\s\s+",
)
data = cleaner.to_dict()
assert data == {
"type": "DocumentCleaner",
"init_parameters": {
"remove_empty_lines": False,
"remove_extra_whitespaces": False,
"remove_repeated_substrings": True,
"remove_substrings": ["a", "b"],
"remove_regex": r"\s\s+",
},
}
@pytest.mark.unit
def test_from_dict(self):
data = {
"type": "DocumentCleaner",
"init_parameters": {
"remove_empty_lines": False,
"remove_extra_whitespaces": False,
"remove_repeated_substrings": True,
"remove_substrings": ["a", "b"],
"remove_regex": r"\s\s+",
},
}
cleaner = DocumentCleaner.from_dict(data)
assert cleaner.remove_empty_lines == False
assert cleaner.remove_extra_whitespaces == False
assert cleaner.remove_repeated_substrings == True
assert cleaner.remove_substrings == ["a", "b"]
assert cleaner.remove_regex == r"\s\s+"
@pytest.mark.unit
def test_non_text_document(self, caplog):
with caplog.at_level(logging.WARNING):

@@ -118,35 +118,6 @@ class TestTextDocumentSplitter:
assert result["documents"][0].text == "This is a text with some words. There is a "
assert result["documents"][1].text == "is a second sentence. And there is a third sentence."
@pytest.mark.unit
def test_to_dict(self):
splitter = TextDocumentSplitter()
data = splitter.to_dict()
assert data == {
"type": "TextDocumentSplitter",
"init_parameters": {"split_by": "word", "split_length": 200, "split_overlap": 0},
}
@pytest.mark.unit
def test_to_dict_with_custom_init_parameters(self):
splitter = TextDocumentSplitter(split_by="passage", split_length=100, split_overlap=1)
data = splitter.to_dict()
assert data == {
"type": "TextDocumentSplitter",
"init_parameters": {"split_by": "passage", "split_length": 100, "split_overlap": 1},
}
@pytest.mark.unit
def test_from_dict(self):
data = {
"type": "TextDocumentSplitter",
"init_parameters": {"split_by": "passage", "split_length": 100, "split_overlap": 1},
}
splitter = TextDocumentSplitter.from_dict(data)
assert splitter.split_by == "passage"
assert splitter.split_length == 100
assert splitter.split_overlap == 1
@pytest.mark.unit
def test_source_id_stored_in_metadata(self):
splitter = TextDocumentSplitter(split_by="word", split_length=10)

@@ -6,18 +6,6 @@ from haystack.preview.components.preprocessors import TextLanguageClassifier
class TestTextLanguageClassifier:
@pytest.mark.unit
def test_to_dict(self):
component = TextLanguageClassifier(languages=["en", "de"])
data = component.to_dict()
assert data == {"type": "TextLanguageClassifier", "init_parameters": {"languages": ["en", "de"]}}
@pytest.mark.unit
def test_from_dict(self):
data = {"type": "TextLanguageClassifier", "init_parameters": {"languages": ["en", "de"]}}
component = TextLanguageClassifier.from_dict(data)
assert component.languages == ["en", "de"]
@pytest.mark.unit
def test_non_string_input(self):
with pytest.raises(TypeError, match="TextLanguageClassifier expects a str as input."):

@@ -33,19 +33,6 @@ class TestSimilarityRanker:
},
}
@pytest.mark.integration
def test_from_dict(self):
data = {
"type": "SimilarityRanker",
"init_parameters": {
"device": "cpu",
"top_k": 10,
"model_name_or_path": "cross-encoder/ms-marco-MiniLM-L-6-v2",
},
}
component = SimilarityRanker.from_dict(data)
assert component.model_name_or_path == "cross-encoder/ms-marco-MiniLM-L-6-v2"
@pytest.mark.integration
@pytest.mark.parametrize(
"query,docs_before_texts,expected_first_text",

@@ -110,39 +110,6 @@ def test_to_dict():
}
@pytest.mark.unit
def test_from_dict():
data = {
"type": "ExtractiveReader",
"init_parameters": {
"model_name_or_path": "my-model",
"device": "cpu",
"token": None,
"top_k": 30,
"confidence_threshold": 0.5,
"max_seq_length": 300,
"stride": 100,
"max_batch_size": 20,
"answers_per_seq": 5,
"no_answer": False,
"calibration_factor": 0.5,
},
}
component = ExtractiveReader.from_dict(data)
assert component.model_name_or_path == "my-model"
assert component.device == "cpu"
assert component.token is None
assert component.top_k == 30
assert component.confidence_threshold == 0.5
assert component.max_seq_length == 300
assert component.stride == 100
assert component.max_batch_size == 20
assert component.answers_per_seq == 5
assert component.no_answer is False
assert component.calibration_factor == 0.5
@pytest.mark.unit
def test_output(mock_reader: ExtractiveReader):
answers = mock_reader.run(example_queries[0], example_documents[0], top_k=3)[

@@ -11,24 +11,6 @@ from haystack.preview.dataclasses import ByteStream
reason="Can't run on Windows Github CI, need access to registry to get mime types",
)
class TestFileTypeRouter:
@pytest.mark.unit
def test_to_dict(self):
component = FileTypeRouter(mime_types=["text/plain", "audio/x-wav", "image/jpeg"])
data = component.to_dict()
assert data == {
"type": "FileTypeRouter",
"init_parameters": {"mime_types": ["text/plain", "audio/x-wav", "image/jpeg"]},
}
@pytest.mark.unit
def test_from_dict(self):
data = {
"type": "FileTypeRouter",
"init_parameters": {"mime_types": ["text/plain", "audio/x-wav", "image/jpeg"]},
}
component = FileTypeRouter.from_dict(data)
assert component.mime_types == ["text/plain", "audio/x-wav", "image/jpeg"]
@pytest.mark.unit
def test_run(self, preview_samples_path):
"""

@@ -5,24 +5,6 @@ from haystack.preview.components.routers.metadata_router import MetadataRouter
class TestMetadataRouter:
@pytest.mark.unit
def test_to_dict(self):
component = MetadataRouter(rules={"edge_1": {"created_at": {"$gte": "2023-01-01", "$lt": "2023-04-01"}}})
data = component.to_dict()
assert data == {
"type": "MetadataRouter",
"init_parameters": {"rules": {"edge_1": {"created_at": {"$gte": "2023-01-01", "$lt": "2023-04-01"}}}},
}
@pytest.mark.unit
def test_from_dict(self):
data = {
"type": "MetadataRouter",
"init_parameters": {"rules": {"edge_1": {"created_at": {"$gte": "2023-01-01", "$lt": "2023-04-01"}}}},
}
component = MetadataRouter.from_dict(data)
assert component.rules == {"edge_1": {"created_at": {"$gte": "2023-01-01", "$lt": "2023-04-01"}}}
@pytest.mark.unit
def test_run(self):
rules = {

@@ -7,24 +7,6 @@ from haystack.preview.components.samplers.top_p import TopPSampler
class TestTopPSampler:
@pytest.mark.unit
def test_to_dict(self):
component = TopPSampler()
data = component.to_dict()
assert data == {"type": "TopPSampler", "init_parameters": {"top_p": 1.0, "score_field": None}}
@pytest.mark.unit
def test_to_dict_with_custom_init_parameters(self):
component = TopPSampler(top_p=0.92)
data = component.to_dict()
assert data == {"type": "TopPSampler", "init_parameters": {"top_p": 0.92, "score_field": None}}
@pytest.mark.unit
def test_from_dict(self):
data = {"type": "TopPSampler", "init_parameters": {"top_p": 0.9, "score_field": None}}
component = TopPSampler.from_dict(data)
assert component.top_p == 0.9
@pytest.mark.unit
def test_run_scores_from_metadata(self):
"""

@@ -124,23 +124,6 @@ class TestSerperDevSearchAPI:
},
}
@pytest.mark.unit
def test_from_dict(self):
data = {
"type": "SerperDevWebSearch",
"init_parameters": {
"api_key": "test_key",
"top_k": 10,
"allowed_domains": ["test.com"],
"search_params": {"param": "test"},
},
}
component = SerperDevWebSearch.from_dict(data)
assert component.api_key == "test_key"
assert component.top_k == 10
assert component.allowed_domains == ["test.com"]
assert component.search_params == {"param": "test"}
@pytest.mark.unit
@pytest.mark.parametrize("top_k", [1, 5, 7])
def test_web_search_top_k(self, mock_serper_dev_search_result, top_k: int):