Mirror of https://github.com/deepset-ai/haystack.git, synced 2025-12-27 06:58:35 +00:00
refactor: Remove reimplementations of default from_dict/to_dict and corresponding tests in 2.0 (#6108)
* whisper transcriber
* remove from/to_dict from builders
* remove from/to_dict from embedders
* remove from/to_dict from fetcher, file_converters
* remove from/to_dict from generators, preprocessors
* remove from/to_dict from ranker, reader
* remove from/to_dict from router, sampler, websearch
* pylint
* reno
* refactor import
* remove unused import
This commit is contained in:
parent
6df077cbb4
commit
9f3b6512be
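
Every file in this diff drops a `from_dict` (and often a `to_dict`) that merely restated the Canals defaults. For orientation, here is a minimal sketch of what those defaults do, inferred from the `{"type": ..., "init_parameters": ...}` shape the updated tests assert; see the linked canals/serialization.py for the authoritative implementation:

```python
from typing import Any, Dict


def default_to_dict(obj: Any, **init_parameters) -> Dict[str, Any]:
    # Serialize a component as its class name plus whichever init parameters
    # the caller passes in.
    return {"type": obj.__class__.__name__, "init_parameters": init_parameters}


def default_from_dict(cls: type, data: Dict[str, Any]) -> Any:
    # Rebuild the component by feeding the stored init parameters back to
    # the constructor.
    return cls(**data.get("init_parameters", {}))
```

A component whose `to_dict`/`from_dict` did nothing beyond this can simply inherit the defaults, which is exactly what the hunks below do.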
@@ -3,7 +3,7 @@ from typing import List, Optional, Dict, Any, Union, BinaryIO, Literal, get_args
 import logging
 from pathlib import Path
 
-from haystack.preview import component, Document, default_to_dict, default_from_dict, ComponentError
+from haystack.preview import component, Document, default_to_dict, ComponentError
 from haystack.preview.lazy_imports import LazyImport
 
 with LazyImport("Run 'pip install openai-whisper'") as whisper_import:
@@ -66,13 +66,6 @@ class LocalWhisperTranscriber:
             self, model_name_or_path=self.model_name, device=str(self.device), whisper_params=self.whisper_params
         )
 
-    @classmethod
-    def from_dict(cls, data: Dict[str, Any]) -> "LocalWhisperTranscriber":
-        """
-        Deserialize this component from a dictionary.
-        """
-        return default_from_dict(cls, data)
-
     @component.output_types(documents=List[Document])
     def run(self, audio_files: List[Path], whisper_params: Optional[Dict[str, Any]] = None):
         """
@@ -6,7 +6,7 @@ import logging
 from pathlib import Path
 
 from haystack.preview.utils import request_with_retry
-from haystack.preview import component, Document, default_to_dict, default_from_dict
+from haystack.preview import component, Document, default_to_dict
 
 logger = logging.getLogger(__name__)
 
@@ -54,25 +54,6 @@ class RemoteWhisperTranscriber:
         self.api_base = api_base
         self.whisper_params = whisper_params or {}
 
-    def to_dict(self) -> Dict[str, Any]:
-        """
-        Serialize this component to a dictionary.
-        """
-        return default_to_dict(
-            self,
-            model_name=self.model_name,
-            api_key=self.api_key,
-            api_base=self.api_base,
-            whisper_params=self.whisper_params,
-        )
-
-    @classmethod
-    def from_dict(cls, data: Dict[str, Any]) -> "RemoteWhisperTranscriber":
-        """
-        Deserialize this component from a dictionary.
-        """
-        return default_from_dict(cls, data)
-
     @component.output_types(documents=List[Document])
     def run(self, audio_files: List[Path], whisper_params: Optional[Dict[str, Any]] = None):
         """
@@ -145,3 +126,12 @@ class RemoteWhisperTranscriber:
 
         transcriptions.append(transcription)
         return transcriptions
+
+    def to_dict(self) -> Dict[str, Any]:
+        """
+        This method overrides the default serializer in order to avoid leaking the `api_key` value passed
+        to the constructor.
+        """
+        return default_to_dict(
+            self, model_name=self.model_name, api_base=self.api_base, whisper_params=self.whisper_params
+        )
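
`RemoteWhisperTranscriber` above is the one component that keeps a custom serializer: its new `to_dict` deliberately omits `api_key`, so the secret never lands in a serialized pipeline. A sketch of that pattern, with hypothetical names (`MyApiComponent`, `MY_API_KEY` are illustrations, not part of this diff):

```python
import os
from typing import Any, Dict, Optional

from haystack.preview import component, default_to_dict


@component
class MyApiComponent:
    def __init__(self, api_key: Optional[str] = None, api_base: str = "https://api.example.com"):
        # Hypothetical env-var fallback so deserialization still works even
        # though the key is never written out.
        self.api_key = api_key or os.environ.get("MY_API_KEY")
        self.api_base = api_base

    def to_dict(self) -> Dict[str, Any]:
        # Omit api_key on purpose; the inherited default from_dict will call
        # __init__ with only these serialized parameters.
        return default_to_dict(self, api_base=self.api_base)

    @component.output_types(result=str)
    def run(self, text: str):
        return {"result": text}
```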
@@ -2,7 +2,7 @@ import logging
 import re
 from typing import List, Dict, Any, Optional
 
-from haystack.preview import component, GeneratedAnswer, Document, default_to_dict, default_from_dict
+from haystack.preview import component, GeneratedAnswer, Document
 
 
 logger = logging.getLogger(__name__)
@@ -107,19 +107,6 @@ class AnswerBuilder:
 
         return {"answers": all_answers}
 
-    def to_dict(self) -> Dict[str, Any]:
-        """
-        Serialize this component to a dictionary.
-        """
-        return default_to_dict(self, pattern=self.pattern, reference_pattern=self.reference_pattern)
-
-    @classmethod
-    def from_dict(cls, data: Dict[str, Any]) -> "AnswerBuilder":
-        """
-        Deserialize this component from a dictionary.
-        """
-        return default_from_dict(cls, data)
-
     @staticmethod
     def _extract_answer_string(reply: str, pattern: Optional[str] = None) -> str:
         """
@@ -3,7 +3,7 @@ from typing import Dict, Any
 from jinja2 import Template, meta
 
 from haystack.preview import component
-from haystack.preview import default_to_dict, default_from_dict
+from haystack.preview import default_to_dict
 
 
 @component
@@ -36,10 +36,6 @@ class PromptBuilder:
     def to_dict(self) -> Dict[str, Any]:
         return default_to_dict(self, template=self._template_string)
 
-    @classmethod
-    def from_dict(cls, data) -> "PromptBuilder":
-        return default_from_dict(cls, data)
-
     @component.output_types(prompt=str)
     def run(self, **kwargs):
         return {"prompt": self.template.render(kwargs)}
@@ -5,7 +5,7 @@ import openai
 from tqdm import tqdm
 
 
-from haystack.preview import component, Document, default_to_dict, default_from_dict
+from haystack.preview import component, Document, default_to_dict
 
 
 @component
@@ -89,13 +89,6 @@ class OpenAIDocumentEmbedder:
             embedding_separator=self.embedding_separator,
         )
 
-    @classmethod
-    def from_dict(cls, data: Dict[str, Any]) -> "OpenAIDocumentEmbedder":
-        """
-        Deserialize this component from a dictionary.
-        """
-        return default_from_dict(cls, data)
-
     def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
         """
         Prepare the texts to embed by concatenating the Document text with the metadata fields to embed.
@@ -3,7 +3,7 @@ import os
 
 import openai
 
-from haystack.preview import component, default_to_dict, default_from_dict
+from haystack.preview import component, default_to_dict
 
 
 @component
@@ -66,13 +66,6 @@ class OpenAITextEmbedder:
             self, model_name=self.model_name, organization=self.organization, prefix=self.prefix, suffix=self.suffix
         )
 
-    @classmethod
-    def from_dict(cls, data: Dict[str, Any]) -> "OpenAITextEmbedder":
-        """
-        Deserialize this component from a dictionary.
-        """
-        return default_from_dict(cls, data)
-
     @component.output_types(embedding=List[float], metadata=Dict[str, Any])
     def run(self, text: str):
         """Embed a string."""
@@ -1,6 +1,6 @@
 from typing import List, Optional, Union, Dict, Any
 
-from haystack.preview import component, Document, default_to_dict, default_from_dict
+from haystack.preview import component, Document, default_to_dict
 from haystack.preview.components.embedders.backends.sentence_transformers_backend import (
     _SentenceTransformersEmbeddingBackendFactory,
 )
@@ -81,13 +81,6 @@ class SentenceTransformersDocumentEmbedder:
             embedding_separator=self.embedding_separator,
         )
 
-    @classmethod
-    def from_dict(cls, data: Dict[str, Any]) -> "SentenceTransformersDocumentEmbedder":
-        """
-        Deserialize this component from a dictionary.
-        """
-        return default_from_dict(cls, data)
-
     def warm_up(self):
         """
         Load the embedding backend.
@@ -1,6 +1,6 @@
 from typing import List, Optional, Union, Dict, Any
 
-from haystack.preview import component, default_to_dict, default_from_dict
+from haystack.preview import component, default_to_dict
 from haystack.preview.components.embedders.backends.sentence_transformers_backend import (
     _SentenceTransformersEmbeddingBackendFactory,
 )
@@ -72,13 +72,6 @@ class SentenceTransformersTextEmbedder:
             normalize_embeddings=self.normalize_embeddings,
         )
 
-    @classmethod
-    def from_dict(cls, data: Dict[str, Any]) -> "SentenceTransformersTextEmbedder":
-        """
-        Deserialize this component from a dictionary.
-        """
-        return default_from_dict(cls, data)
-
     def warm_up(self):
         """
         Load the embedding backend.
@@ -1,14 +1,14 @@
 import logging
 from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Callable, Dict, List, Optional, Tuple
 
 import requests
 from requests import Response
 from requests.exceptions import HTTPError
 from tenacity import RetryCallState, retry, retry_if_exception_type, stop_after_attempt, wait_exponential
 
-from haystack.preview import component, default_from_dict, default_to_dict
+from haystack.preview import component
 from haystack.preview.dataclasses import ByteStream
 from haystack.preview.version import __version__
 
@@ -95,25 +95,6 @@ class LinkContentFetcher:
 
         self._get_response: Callable = get_response
 
-    def to_dict(self) -> Dict[str, Any]:
-        """
-        Serialize this component to a dictionary.
-        """
-        return default_to_dict(
-            self,
-            raise_on_failure=self.raise_on_failure,
-            user_agents=self.user_agents,
-            retry_attempts=self.retry_attempts,
-            timeout=self.timeout,
-        )
-
-    @classmethod
-    def from_dict(cls, data: Dict[str, Any]) -> "LinkContentFetcher":
-        """
-        Deserialize this component from a dictionary.
-        """
-        return default_from_dict(cls, data)
-
     @component.output_types(streams=List[ByteStream])
     def run(self, urls: List[str]):
         """
@@ -2,7 +2,7 @@ from pathlib import Path
 from typing import List, Union, Optional, Dict, Any
 
 from haystack.preview.lazy_imports import LazyImport
-from haystack.preview import component, Document, default_to_dict, default_from_dict
+from haystack.preview import component, Document, default_to_dict
 
 
 with LazyImport(message="Run 'pip install azure-ai-formrecognizer>=3.2.0b2'") as azure_import:
@@ -84,13 +84,6 @@ class AzureOCRDocumentConverter:
             self, endpoint=self.endpoint, api_key=self.api_key, model_id=self.model_id, id_hash_keys=self.id_hash_keys
         )
 
-    @classmethod
-    def from_dict(cls, data: Dict[str, Any]) -> "AzureOCRDocumentConverter":
-        """
-        Deserialize this component from a dictionary.
-        """
-        return default_from_dict(cls, data)
-
     @staticmethod
     def _convert_azure_result_to_document(
         result: "AnalyzeResult", id_hash_keys: List[str], file_suffix: str
@@ -1,8 +1,8 @@
 import logging
-from typing import List, Optional, Dict, Any, Union
+from typing import List, Optional, Union
 from pathlib import Path
 
-from haystack.preview import Document, component, default_to_dict, default_from_dict
+from haystack.preview import Document, component
 from haystack.preview.dataclasses import ByteStream
 from haystack.preview.lazy_imports import LazyImport
 
@@ -27,15 +27,6 @@ class HTMLToDocument:
         boilerpy3_import.check()
         self.id_hash_keys = id_hash_keys or []
 
-    def to_dict(self) -> Dict[str, Any]:
-        """Serialize the component to a dictionary."""
-        return default_to_dict(self, id_hash_keys=self.id_hash_keys)
-
-    @classmethod
-    def from_dict(cls, data: Dict[str, Any]) -> "HTMLToDocument":
-        """Deserialize the component from a dictionary."""
-        return default_from_dict(cls, data)
-
     @component.output_types(documents=List[Document])
     def run(self, sources: List[Union[str, Path, ByteStream]]):
         """
@@ -1,11 +1,11 @@
 import io
 import logging
-from typing import List, Optional, Dict, Any, Union
+from typing import List, Optional, Union
 from pathlib import Path
 
 from haystack.preview.dataclasses import ByteStream
 from haystack.preview.lazy_imports import LazyImport
-from haystack.preview import Document, component, default_to_dict, default_from_dict
+from haystack.preview import Document, component
 
 with LazyImport("Run 'pip install pypdf'") as pypdf_import:
     from pypdf import PdfReader
@@ -30,22 +30,6 @@ class PyPDFToDocument:
         pypdf_import.check()
         self.id_hash_keys = id_hash_keys or []
 
-    def to_dict(self) -> Dict[str, Any]:
-        """
-        Serialize this component to a dictionary.
-        :return: The dictionary containing the component's data.
-        """
-        return default_to_dict(self, id_hash_keys=self.id_hash_keys)
-
-    @classmethod
-    def from_dict(cls, data: Dict[str, Any]) -> "PyPDFToDocument":
-        """
-        Deserialize this component from a dictionary.
-        :param data: The dictionary containing the component's data.
-        :return: The component instance.
-        """
-        return default_from_dict(cls, data)
-
     @component.output_types(documents=List[Document])
     def run(self, sources: List[Union[str, Path, ByteStream]], id_hash_keys: Optional[List[str]] = None):
         """
@@ -1,9 +1,9 @@
 import logging
 from pathlib import Path
-from typing import Optional, List, Union, Dict, Any
+from typing import Optional, List, Union
 
 from haystack.preview.lazy_imports import LazyImport
-from haystack.preview import component, Document, default_to_dict, default_from_dict
+from haystack.preview import component, Document
 
 
 with LazyImport("Run 'pip install tika'") as tika_import:
@@ -70,16 +70,3 @@ class TikaDocumentConverter:
                 logger.error("Could not convert file at '%s' to Document. Error: %s", str(path), e)
 
         return {"documents": documents}
-
-    def to_dict(self) -> Dict[str, Any]:
-        """
-        Serialize this component to a dictionary.
-        """
-        return default_to_dict(self, tika_url=self.tika_url, id_hash_keys=self.id_hash_keys)
-
-    @classmethod
-    def from_dict(cls, data: Dict[str, Any]) -> "TikaDocumentConverter":
-        """
-        Deserialize this component from a dictionary.
-        """
-        return default_from_dict(cls, data)
@@ -1,12 +1,12 @@
 import logging
 from pathlib import Path
-from typing import Optional, List, Union, Dict, Any
+from typing import Optional, List, Union, Dict
 
 from canals.errors import PipelineRuntimeError
 from tqdm import tqdm
 
 from haystack.preview.lazy_imports import LazyImport
-from haystack.preview import Document, component, default_to_dict, default_from_dict
+from haystack.preview import Document, component
 
 with LazyImport("Run 'pip install langdetect'") as langdetect_import:
     import langdetect
@@ -61,27 +61,6 @@ class TextFileToDocument:
         self.id_hash_keys = id_hash_keys or []
         self.progress_bar = progress_bar
 
-    def to_dict(self) -> Dict[str, Any]:
-        """
-        Serialize this component to a dictionary.
-        """
-        return default_to_dict(
-            self,
-            encoding=self.encoding,
-            remove_numeric_tables=self.remove_numeric_tables,
-            numeric_row_threshold=self.numeric_row_threshold,
-            valid_languages=self.valid_languages,
-            id_hash_keys=self.id_hash_keys,
-            progress_bar=self.progress_bar,
-        )
-
-    @classmethod
-    def from_dict(cls, data: Dict[str, Any]) -> "TextFileToDocument":
-        """
-        Deserialize this component from a dictionary.
-        """
-        return default_from_dict(cls, data)
-
     @component.output_types(documents=List[Document])
     def run(
         self,
@@ -2,7 +2,7 @@ import logging
 from typing import Any, Dict, List, Literal, Optional, Union
 from copy import deepcopy
 
-from haystack.preview import component, default_from_dict, default_to_dict
+from haystack.preview import component, default_to_dict
 from haystack.preview.lazy_imports import LazyImport
 
 with LazyImport(message="Run 'pip install transformers'") as transformers_import:
@@ -170,13 +170,6 @@ class HuggingFaceLocalGenerator:
             stop_words=self.stop_words,
         )
 
-    @classmethod
-    def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceLocalGenerator":
-        """
-        Deserialize this component from a dictionary.
-        """
-        return default_from_dict(cls, data)
-
     @component.output_types(replies=List[str])
     def run(self, prompt: str):
         if self.pipeline is None:
@@ -3,9 +3,9 @@ import re
 from copy import deepcopy
 from functools import partial, reduce
 from itertools import chain
-from typing import Any, Dict, Generator, List, Optional, Set
+from typing import Generator, List, Optional, Set
 
-from haystack.preview import Document, component, default_from_dict, default_to_dict
+from haystack.preview import Document, component
 
 logger = logging.getLogger(__name__)
 
@@ -91,26 +91,6 @@ class DocumentCleaner:
 
         return {"documents": cleaned_docs}
 
-    def to_dict(self) -> Dict[str, Any]:
-        """
-        Serialize this component to a dictionary.
-        """
-        return default_to_dict(
-            self,
-            remove_empty_lines=self.remove_empty_lines,
-            remove_extra_whitespaces=self.remove_extra_whitespaces,
-            remove_repeated_substrings=self.remove_repeated_substrings,
-            remove_substrings=self.remove_substrings,
-            remove_regex=self.remove_regex,
-        )
-
-    @classmethod
-    def from_dict(cls, data: Dict[str, Any]) -> "DocumentCleaner":
-        """
-        Deserialize this component from a dictionary.
-        """
-        return default_from_dict(cls, data)
-
     def _remove_empty_lines(self, text: str) -> str:
         """
         Remove empty lines and lines that contain nothing but whitespaces from text.
@@ -1,9 +1,9 @@
 from copy import deepcopy
-from typing import List, Dict, Any, Literal
+from typing import List, Literal
 
 from more_itertools import windowed
 
-from haystack.preview import component, Document, default_from_dict, default_to_dict
+from haystack.preview import component, Document
 
 
 @component
@@ -61,21 +61,6 @@ class TextDocumentSplitter:
         split_docs += [Document(text=txt, metadata=metadata, id_hash_keys=id_hash_keys) for txt in text_splits]
         return {"documents": split_docs}
 
-    def to_dict(self) -> Dict[str, Any]:
-        """
-        Serialize this component to a dictionary.
-        """
-        return default_to_dict(
-            self, split_by=self.split_by, split_length=self.split_length, split_overlap=self.split_overlap
-        )
-
-    @classmethod
-    def from_dict(cls, data: Dict[str, Any]) -> "TextDocumentSplitter":
-        """
-        Deserialize this component from a dictionary.
-        """
-        return default_from_dict(cls, data)
-
     def _split_into_units(self, text: str, split_by: Literal["word", "sentence", "passage"]) -> List[str]:
         if split_by == "passage":
             split_at = "\n\n"
@@ -1,7 +1,7 @@
 import logging
-from typing import List, Dict, Any, Optional
+from typing import List, Dict, Optional
 
-from haystack.preview import component, default_from_dict, default_to_dict
+from haystack.preview import component
 from haystack.preview.lazy_imports import LazyImport
 
 logger = logging.getLogger(__name__)
@@ -63,19 +63,6 @@ class TextLanguageClassifier:
 
         return output
 
-    def to_dict(self) -> Dict[str, Any]:
-        """
-        Serialize this component to a dictionary.
-        """
-        return default_to_dict(self, languages=self.languages)
-
-    @classmethod
-    def from_dict(cls, data: Dict[str, Any]) -> "TextLanguageClassifier":
-        """
-        Deserialize this component from a dictionary.
-        """
-        return default_from_dict(cls, data)
-
     def detect_language(self, text: str) -> Optional[str]:
         try:
             language = langdetect.detect(text)
@@ -2,7 +2,7 @@ import logging
 from pathlib import Path
 from typing import List, Union, Dict, Any, Optional
 
-from haystack.preview import ComponentError, Document, component, default_from_dict, default_to_dict
+from haystack.preview import ComponentError, Document, component, default_to_dict
 from haystack.preview.lazy_imports import LazyImport
 
 logger = logging.getLogger(__name__)
@@ -89,13 +89,6 @@ class SimilarityRanker:
             top_k=self.top_k,
         )
 
-    @classmethod
-    def from_dict(cls, data: Dict[str, Any]) -> "SimilarityRanker":
-        """
-        Deserialize this component from a dictionary.
-        """
-        return default_from_dict(cls, data)
-
     @component.output_types(documents=List[Document])
     def run(self, query: str, documents: List[Document], top_k: Optional[int] = None):
         """
@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 import math
 import warnings
 
-from haystack.preview import component, default_from_dict, default_to_dict, ComponentError, Document, ExtractedAnswer
+from haystack.preview import component, default_to_dict, ComponentError, Document, ExtractedAnswer
 from haystack.preview.lazy_imports import LazyImport
 
 with LazyImport(
@@ -97,13 +97,6 @@ class ExtractiveReader:
             calibration_factor=self.calibration_factor,
         )
 
-    @classmethod
-    def from_dict(cls, data: Dict[str, Any]) -> "ExtractiveReader":
-        """
-        Deserialize this component from a dictionary.
-        """
-        return default_from_dict(cls, data)
-
     def warm_up(self):
         if self.model is None:
             if torch.cuda.is_available():
@@ -2,9 +2,9 @@ import logging
 import mimetypes
 from collections import defaultdict
 from pathlib import Path
-from typing import List, Union, Optional, Dict, Any
+from typing import List, Union, Optional, Dict
 
-from haystack.preview import component, default_from_dict, default_to_dict
+from haystack.preview import component
 from haystack.preview.dataclasses import ByteStream
 
 logger = logging.getLogger(__name__)
@@ -42,19 +42,6 @@ class FileTypeRouter:
         component.set_output_types(self, unclassified=List[Path], **{mime_type: List[Path] for mime_type in mime_types})
         self.mime_types = mime_types
 
-    def to_dict(self) -> Dict[str, Any]:
-        """
-        Serialize this component to a dictionary.
-        """
-        return default_to_dict(self, mime_types=self.mime_types)
-
-    @classmethod
-    def from_dict(cls, data: Dict[str, Any]) -> "FileTypeRouter":
-        """
-        Deserialize this component from a dictionary.
-        """
-        return default_from_dict(cls, data)
-
     def run(self, sources: List[Union[str, Path, ByteStream]]) -> Dict[str, List[Union[ByteStream, Path]]]:
         """
         Categorizes the provided data sources by their MIME types.
@@ -1,6 +1,6 @@
-from typing import Any, Dict, List
+from typing import Dict, List
 
-from haystack.preview import component, default_from_dict, default_to_dict, Document
+from haystack.preview import component, Document
 from haystack.preview.utils.filters import document_matches_filter
 
 
@@ -52,16 +52,3 @@ class MetadataRouter:
 
         output["unmatched"] = unmatched_documents
         return output
-
-    def to_dict(self) -> Dict[str, Any]:
-        """
-        Serialize this component to a dictionary.
-        """
-        return default_to_dict(self, rules=self.rules)
-
-    @classmethod
-    def from_dict(cls, data: Dict[str, Any]) -> "MetadataRouter":
-        """
-        Deserialize this component from a dictionary.
-        """
-        return default_from_dict(cls, data)
@@ -1,7 +1,7 @@
 import logging
-from typing import List, Optional, Dict, Any
+from typing import List, Optional
 
-from haystack.preview import ComponentError, Document, component, default_from_dict, default_to_dict
+from haystack.preview import ComponentError, Document, component
 from haystack.preview.lazy_imports import LazyImport
 
 logger = logging.getLogger(__name__)
@@ -48,19 +48,6 @@ class TopPSampler:
         self.top_p = top_p
         self.score_field = score_field
 
-    def to_dict(self) -> Dict[str, Any]:
-        """
-        Serialize this component to a dictionary.
-        """
-        return default_to_dict(self, top_p=self.top_p, score_field=self.score_field)
-
-    @classmethod
-    def from_dict(cls, data: Dict[str, Any]) -> "TopPSampler":
-        """
-        Deserialize this component from a dictionary.
-        """
-        return default_from_dict(cls, data)
-
     @component.output_types(documents=List[Document])
     def run(self, documents: List[Document], top_p: Optional[float] = None):
         """
@@ -4,7 +4,7 @@ from typing import Dict, List, Optional, Any
 
 import requests
 
-from haystack.preview import Document, component, default_from_dict, default_to_dict, ComponentError
+from haystack.preview import Document, component, default_to_dict, ComponentError
 
 logger = logging.getLogger(__name__)
 
@@ -58,13 +58,6 @@ class SerperDevWebSearch:
             search_params=self.search_params,
         )
 
-    @classmethod
-    def from_dict(cls, data: Dict[str, Any]) -> "SerperDevWebSearch":
-        """
-        Deserialize this component from a dictionary.
-        """
-        return default_from_dict(cls, data)
-
     @component.output_types(documents=List[Document], links=List[str])
     def run(self, query: str):
         """
@@ -0,0 +1,4 @@
+---
+preview:
+  - |
+    Removed implementations of from_dict and to_dict from all components where they had the same effect as the default implementation from Canals: https://github.com/deepset-ai/canals/blob/main/canals/serialization.py#L12-L13 This refactoring does not change the behavior of the components.
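
To see why the release note can claim behavior is unchanged, here is a round-trip sketch through the default helpers using the values from the removed `TopPSampler` tests below; the expected dictionary is taken verbatim from the deleted assertions, and the direct use of the helpers (rather than the removed methods) is the only assumption:

```python
from haystack.preview import default_from_dict, default_to_dict
from haystack.preview.components.samplers.top_p import TopPSampler

sampler = TopPSampler(top_p=0.9)
# What the removed TopPSampler.to_dict produced, now via the default helper.
data = default_to_dict(sampler, top_p=sampler.top_p, score_field=sampler.score_field)
assert data == {"type": "TopPSampler", "init_parameters": {"top_p": 0.9, "score_field": None}}

# What the removed TopPSampler.from_dict produced, now via the default helper.
restored = default_from_dict(TopPSampler, data)
assert restored.top_p == 0.9
```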
@@ -53,21 +53,6 @@ class TestLocalWhisperTranscriber:
             },
         }
 
-    @pytest.mark.unit
-    def test_from_dict(self):
-        data = {
-            "type": "LocalWhisperTranscriber",
-            "init_parameters": {
-                "model_name_or_path": "tiny",
-                "device": "cuda",
-                "whisper_params": {"return_segments": True, "temperature": [0.1, 0.6, 0.8]},
-            },
-        }
-        transcriber = LocalWhisperTranscriber.from_dict(data)
-        assert transcriber.model_name == "tiny"
-        assert transcriber.device == torch.device("cuda")
-        assert transcriber.whisper_params == {"return_segments": True, "temperature": [0.1, 0.6, 0.8]}
-
     @pytest.mark.unit
     def test_warmup(self):
         with patch("haystack.preview.components.audio.whisper_local.whisper") as mocked_whisper:
@@ -33,7 +33,6 @@ class TestRemoteWhisperTranscriber:
             "type": "RemoteWhisperTranscriber",
             "init_parameters": {
                 "model_name": "whisper-1",
-                "api_key": "test",
                 "api_base": "https://api.openai.com/v1",
                 "whisper_params": {},
             },
@@ -52,29 +51,11 @@ class TestRemoteWhisperTranscriber:
             "type": "RemoteWhisperTranscriber",
             "init_parameters": {
                 "model_name": "whisper-1",
-                "api_key": "test",
                 "api_base": "https://my.api.base/something_else/v3",
                 "whisper_params": {"return_segments": True, "temperature": [0.1, 0.6, 0.8]},
             },
         }
 
-    @pytest.mark.unit
-    def test_from_dict(self):
-        data = {
-            "type": "RemoteWhisperTranscriber",
-            "init_parameters": {
-                "model_name": "whisper-1",
-                "api_key": "test",
-                "api_base": "https://my.api.base/something_else/v3",
-                "whisper_params": {"return_segments": True, "temperature": [0.1, 0.6, 0.8]},
-            },
-        }
-        transcriber = RemoteWhisperTranscriber.from_dict(data)
-        assert transcriber.model_name == "whisper-1"
-        assert transcriber.api_key == "test"
-        assert transcriber.api_base == "https://my.api.base/something_else/v3"
-        assert transcriber.whisper_params == {"return_segments": True, "temperature": [0.1, 0.6, 0.8]}
-
     @pytest.mark.unit
     def test_run_with_path(self, preview_samples_path):
         mock_response = MagicMock()
@@ -7,31 +7,6 @@ from haystack.preview.components.builders.answer_builder import AnswerBuilder
 
 
 class TestAnswerBuilder:
-    @pytest.mark.unit
-    def test_to_dict(self):
-        component = AnswerBuilder()
-        data = component.to_dict()
-        assert data == {"type": "AnswerBuilder", "init_parameters": {"pattern": None, "reference_pattern": None}}
-
-    @pytest.mark.unit
-    def test_to_dict_with_custom_init_parameters(self):
-        component = AnswerBuilder(pattern="pattern", reference_pattern="reference_pattern")
-        data = component.to_dict()
-        assert data == {
-            "type": "AnswerBuilder",
-            "init_parameters": {"pattern": "pattern", "reference_pattern": "reference_pattern"},
-        }
-
-    @pytest.mark.unit
-    def test_from_dict(self):
-        data = {
-            "type": "AnswerBuilder",
-            "init_parameters": {"pattern": "pattern", "reference_pattern": "reference_pattern"},
-        }
-        component = AnswerBuilder.from_dict(data)
-        assert component.pattern == "pattern"
-        assert component.reference_pattern == "reference_pattern"
-
     @pytest.mark.unit
     def test_run_unmatching_input_len(self):
         component = AnswerBuilder()
@@ -16,13 +16,6 @@ def test_to_dict():
     assert res == {"type": "PromptBuilder", "init_parameters": {"template": "This is a {{ variable }}"}}
 
 
-@pytest.mark.unit
-def test_from_dict():
-    data = {"type": "PromptBuilder", "init_parameters": {"template": "This is a {{ variable }}"}}
-    builder = PromptBuilder.from_dict(data)
-    builder._template_string == "This is a {{ variable }}"
-
-
 @pytest.mark.unit
 def test_run():
     builder = PromptBuilder(template="This is a {{ variable }}")
@@ -118,53 +118,6 @@ class TestOpenAIDocumentEmbedder:
             },
         }
 
-    @pytest.mark.unit
-    def test_from_dict(self, monkeypatch):
-        monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key")
-        data = {
-            "type": "OpenAIDocumentEmbedder",
-            "init_parameters": {
-                "model_name": "model",
-                "organization": "my-org",
-                "prefix": "prefix",
-                "suffix": "suffix",
-                "batch_size": 64,
-                "progress_bar": False,
-                "metadata_fields_to_embed": ["test_field"],
-                "embedding_separator": " | ",
-            },
-        }
-        component = OpenAIDocumentEmbedder.from_dict(data)
-        assert openai.api_key == "fake-api-key"
-        assert component.model_name == "model"
-        assert component.organization == "my-org"
-        assert openai.organization == "my-org"
-        assert component.prefix == "prefix"
-        assert component.suffix == "suffix"
-        assert component.batch_size == 64
-        assert component.progress_bar is False
-        assert component.metadata_fields_to_embed == ["test_field"]
-        assert component.embedding_separator == " | "
-
-    @pytest.mark.unit
-    def test_from_dict_fail_wo_env_var(self, monkeypatch):
-        monkeypatch.delenv("OPENAI_API_KEY", raising=False)
-        data = {
-            "type": "OpenAIDocumentEmbedder",
-            "init_parameters": {
-                "model_name": "model",
-                "organization": "my-org",
-                "prefix": "prefix",
-                "suffix": "suffix",
-                "batch_size": 64,
-                "progress_bar": False,
-                "metadata_fields_to_embed": ["test_field"],
-                "embedding_separator": " | ",
-            },
-        }
-        with pytest.raises(ValueError, match="OpenAIDocumentEmbedder expects an OpenAI API key"):
-            OpenAIDocumentEmbedder.from_dict(data)
-
     @pytest.mark.unit
     def test_prepare_texts_to_embed_w_metadata(self):
         documents = [
@@ -86,41 +86,6 @@ class TestOpenAITextEmbedder:
             },
         }
 
-    @pytest.mark.unit
-    def test_from_dict(self, monkeypatch):
-        monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key")
-        data = {
-            "type": "OpenAITextEmbedder",
-            "init_parameters": {
-                "model_name": "model",
-                "organization": "fake-organization",
-                "prefix": "prefix",
-                "suffix": "suffix",
-            },
-        }
-        component = OpenAITextEmbedder.from_dict(data)
-        assert openai.api_key == "fake-api-key"
-        assert component.model_name == "model"
-        assert component.organization == "fake-organization"
-        assert openai.organization == "fake-organization"
-        assert component.prefix == "prefix"
-        assert component.suffix == "suffix"
-
-    @pytest.mark.unit
-    def test_from_dict_fail_wo_env_var(self, monkeypatch):
-        monkeypatch.delenv("OPENAI_API_KEY", raising=False)
-        data = {
-            "type": "OpenAITextEmbedder",
-            "init_parameters": {
-                "model_name": "model",
-                "organization": "fake-organization",
-                "prefix": "prefix",
-                "suffix": "suffix",
-            },
-        }
-        with pytest.raises(ValueError, match="OpenAITextEmbedder expects an OpenAI API key"):
-            OpenAITextEmbedder.from_dict(data)
-
     @pytest.mark.unit
     def test_run(self):
         model = "text-similarity-ada-001"
@@ -138,7 +103,7 @@ class TestOpenAITextEmbedder:
         )
 
         assert len(result["embedding"]) == 1536
-        assert all([isinstance(x, float) for x in result["embedding"]])
+        assert all(isinstance(x, float) for x in result["embedding"])
         assert result["metadata"] == {"model": model, "usage": {"prompt_tokens": 4, "total_tokens": 4}}
 
     @pytest.mark.unit
@@ -100,35 +100,6 @@ class TestSentenceTransformersDocumentEmbedder:
             },
         }
 
-    @pytest.mark.unit
-    def test_from_dict(self):
-        data = {
-            "type": "SentenceTransformersDocumentEmbedder",
-            "init_parameters": {
-                "model_name_or_path": "model",
-                "device": "cuda",
-                "token": None,
-                "prefix": "prefix",
-                "suffix": "suffix",
-                "batch_size": 64,
-                "progress_bar": False,
-                "normalize_embeddings": False,
-                "embedding_separator": " - ",
-                "metadata_fields_to_embed": ["meta_field"],
-            },
-        }
-        component = SentenceTransformersDocumentEmbedder.from_dict(data)
-        assert component.model_name_or_path == "model"
-        assert component.device == "cuda"
-        assert component.token is None
-        assert component.prefix == "prefix"
-        assert component.suffix == "suffix"
-        assert component.batch_size == 64
-        assert component.progress_bar is False
-        assert component.normalize_embeddings is False
-        assert component.metadata_fields_to_embed == ["meta_field"]
-        assert component.embedding_separator == " - "
-
     @pytest.mark.unit
     @patch(
         "haystack.preview.components.embedders.sentence_transformers_document_embedder._SentenceTransformersEmbeddingBackendFactory"
@@ -103,31 +103,6 @@ class TestSentenceTransformersTextEmbedder:
             },
         }
 
-    @pytest.mark.unit
-    def test_from_dict(self):
-        data = {
-            "type": "SentenceTransformersTextEmbedder",
-            "init_parameters": {
-                "model_name_or_path": "model",
-                "device": "cuda",
-                "token": True,
-                "prefix": "prefix",
-                "suffix": "suffix",
-                "batch_size": 64,
-                "progress_bar": False,
-                "normalize_embeddings": True,
-            },
-        }
-        component = SentenceTransformersTextEmbedder.from_dict(data)
-        assert component.model_name_or_path == "model"
-        assert component.device == "cuda"
-        assert component.token is True
-        assert component.prefix == "prefix"
-        assert component.suffix == "suffix"
-        assert component.batch_size == 64
-        assert component.progress_bar is False
-        assert component.normalize_embeddings is True
-
     @pytest.mark.unit
     @patch(
         "haystack.preview.components.embedders.sentence_transformers_text_embedder._SentenceTransformersEmbeddingBackendFactory"
@@ -59,44 +59,6 @@ class TestLinkContentFetcher:
         assert fetcher.retry_attempts == 1
         assert fetcher.timeout == 2
 
-    @pytest.mark.unit
-    def test_to_dict(self):
-        fetcher = LinkContentFetcher()
-        assert fetcher.to_dict() == {
-            "type": "LinkContentFetcher",
-            "init_parameters": {
-                "raise_on_failure": True,
-                "user_agents": [DEFAULT_USER_AGENT],
-                "retry_attempts": 2,
-                "timeout": 3,
-            },
-        }
-
-    @pytest.mark.unit
-    def test_to_dict_with_params(self):
-        fetcher = LinkContentFetcher(raise_on_failure=False, user_agents=["test"], retry_attempts=1, timeout=2)
-        assert fetcher.to_dict() == {
-            "type": "LinkContentFetcher",
-            "init_parameters": {"raise_on_failure": False, "user_agents": ["test"], "retry_attempts": 1, "timeout": 2},
-        }
-
-    @pytest.mark.unit
-    def test_from_dict(self):
-        fetcher = LinkContentFetcher.from_dict(
-            {
-                "type": "LinkContentFetcher",
-                "init_parameters": {
-                    "raise_on_failure": False,
-                    "user_agents": ["test"],
-                    "retry_attempts": 1,
-                    "timeout": 2,
-                },
-            }
-        )
-        assert fetcher.raise_on_failure is False
-        assert fetcher.user_agents == ["test"]
-        assert fetcher.retry_attempts == 1
-
     @pytest.mark.unit
     def test_run_text(self):
         correct_response = b"Example test response"
@@ -21,23 +21,6 @@ class TestAzureOCRDocumentConverter:
             },
         }
 
-    @pytest.mark.unit
-    def test_from_dict(self):
-        data = {
-            "type": "AzureOCRDocumentConverter",
-            "init_parameters": {
-                "api_key": "test_credential_key",
-                "endpoint": "test_endpoint",
-                "id_hash_keys": [],
-                "model_id": "prebuilt-read",
-            },
-        }
-        component = AzureOCRDocumentConverter.from_dict(data)
-        assert component.endpoint == "test_endpoint"
-        assert component.api_key == "test_credential_key"
-        assert component.id_hash_keys == []
-        assert component.model_id == "prebuilt-read"
-
     @pytest.mark.unit
     def test_run(self, preview_samples_path):
         with patch("haystack.preview.components.file_converters.azure.DocumentAnalysisClient") as mock_azure_client:
@@ -7,24 +7,6 @@ from haystack.preview.dataclasses import ByteStream
 
 
 class TestHTMLToDocument:
-    @pytest.mark.unit
-    def test_to_dict(self):
-        component = HTMLToDocument()
-        data = component.to_dict()
-        assert data == {"type": "HTMLToDocument", "init_parameters": {"id_hash_keys": []}}
-
-    @pytest.mark.unit
-    def test_to_dict_with_custom_init_parameters(self):
-        component = HTMLToDocument(id_hash_keys=["name"])
-        data = component.to_dict()
-        assert data == {"type": "HTMLToDocument", "init_parameters": {"id_hash_keys": ["name"]}}
-
-    @pytest.mark.unit
-    def test_from_dict(self):
-        data = {"type": "HTMLToDocument", "init_parameters": {"id_hash_keys": ["name"]}}
-        component = HTMLToDocument.from_dict(data)
-        assert component.id_hash_keys == ["name"]
-
     @pytest.mark.unit
     def test_run(self, preview_samples_path):
         """
@@ -7,24 +7,6 @@ from haystack.preview.dataclasses import ByteStream
 
 
 class TestPyPDFToDocument:
-    @pytest.mark.unit
-    def test_to_dict(self):
-        component = PyPDFToDocument()
-        data = component.to_dict()
-        assert data == {"type": "PyPDFToDocument", "init_parameters": {"id_hash_keys": []}}
-
-    @pytest.mark.unit
-    def test_to_dict_with_custom_init_parameters(self):
-        component = PyPDFToDocument(id_hash_keys=["name"])
-        data = component.to_dict()
-        assert data == {"type": "PyPDFToDocument", "init_parameters": {"id_hash_keys": ["name"]}}
-
-    @pytest.mark.unit
-    def test_from_dict(self):
-        data = {"type": "PyPDFToDocument", "init_parameters": {"id_hash_keys": ["name"]}}
-        component = PyPDFToDocument.from_dict(data)
-        assert component.id_hash_keys == ["name"]
-
     @pytest.mark.unit
     def test_run(self, preview_samples_path):
         """
@@ -11,66 +11,6 @@ from haystack.preview.components.file_converters.txt import TextFileToDocument
 
 
 class TestTextfileToDocument:
-    @pytest.mark.unit
-    def test_to_dict(self):
-        component = TextFileToDocument()
-        data = component.to_dict()
-        assert data == {
-            "type": "TextFileToDocument",
-            "init_parameters": {
-                "encoding": "utf-8",
-                "remove_numeric_tables": False,
-                "numeric_row_threshold": 0.4,
-                "valid_languages": [],
-                "id_hash_keys": [],
-                "progress_bar": True,
-            },
-        }
-
-    @pytest.mark.unit
-    def test_to_dict_with_custom_init_parameters(self):
-        component = TextFileToDocument(
-            encoding="latin-1",
-            remove_numeric_tables=True,
-            numeric_row_threshold=0.7,
-            valid_languages=["en", "de"],
-            id_hash_keys=["name"],
-            progress_bar=False,
-        )
-        data = component.to_dict()
-        assert data == {
-            "type": "TextFileToDocument",
-            "init_parameters": {
-                "encoding": "latin-1",
-                "remove_numeric_tables": True,
-                "numeric_row_threshold": 0.7,
-                "valid_languages": ["en", "de"],
-                "id_hash_keys": ["name"],
-                "progress_bar": False,
-            },
-        }
-
-    @pytest.mark.unit
-    def test_from_dict(self):
-        data = {
-            "type": "TextFileToDocument",
-            "init_parameters": {
-                "encoding": "latin-1",
-                "remove_numeric_tables": True,
-                "numeric_row_threshold": 0.7,
-                "valid_languages": ["en", "de"],
-                "id_hash_keys": ["name"],
-                "progress_bar": False,
-            },
-        }
-        component = TextFileToDocument.from_dict(data)
-        assert component.encoding == "latin-1"
-        assert component.remove_numeric_tables
-        assert component.numeric_row_threshold == 0.7
-        assert component.valid_languages == ["en", "de"]
-        assert component.id_hash_keys == ["name"]
-        assert not component.progress_bar
-
     @pytest.mark.unit
     def test_run(self, preview_samples_path):
         """
@@ -6,34 +6,6 @@ from haystack.preview.components.file_converters.tika import TikaDocumentConvert
 
 
 class TestTikaDocumentConverter:
-    @pytest.mark.unit
-    def test_to_dict(self):
-        component = TikaDocumentConverter()
-        data = component.to_dict()
-        assert data == {
-            "type": "TikaDocumentConverter",
-            "init_parameters": {"tika_url": "http://localhost:9998/tika", "id_hash_keys": []},
-        }
-
-    @pytest.mark.unit
-    def test_to_dict_with_custom_init_parameters(self):
-        component = TikaDocumentConverter(tika_url="http://localhost:1234/tika", id_hash_keys=["text", "category"])
-        data = component.to_dict()
-        assert data == {
-            "type": "TikaDocumentConverter",
-            "init_parameters": {"tika_url": "http://localhost:1234/tika", "id_hash_keys": ["text", "category"]},
-        }
-
-    @pytest.mark.unit
-    def test_from_dict(self):
-        data = {
-            "type": "TikaDocumentConverter",
-            "init_parameters": {"tika_url": "http://localhost:9998/tika", "id_hash_keys": ["text", "category"]},
-        }
-        component = TikaDocumentConverter.from_dict(data)
-        assert component.tika_url == "http://localhost:9998/tika"
-        assert component.id_hash_keys == ["text", "category"]
-
     @pytest.mark.unit
     def test_run(self):
         component = TikaDocumentConverter()
@@ -181,31 +181,6 @@ class TestHuggingFaceLocalGenerator:
             },
         }
 
-    @pytest.mark.unit
-    def test_from_dict(self):
-        data = {
-            "type": "HuggingFaceLocalGenerator",
-            "init_parameters": {
-                "pipeline_kwargs": {
-                    "model": "gpt2",
-                    "task": "text-generation",
-                    "token": "test-token",
-                    "device": "cuda:0",
-                },
-                "generation_kwargs": {"max_new_tokens": 100, "return_full_text": False},
-            },
-        }
-
-        component = HuggingFaceLocalGenerator.from_dict(data)
-
-        assert component.pipeline_kwargs == {
-            "model": "gpt2",
-            "task": "text-generation",
-            "token": "test-token",
-            "device": "cuda:0",
-        }
-        assert component.generation_kwargs == {"max_new_tokens": 100, "return_full_text": False}
-
     @pytest.mark.unit
     @patch("haystack.preview.components.generators.hugging_face.hugging_face_local.pipeline")
     def test_warm_up(self, pipeline_mock):
@@ -16,61 +16,6 @@ class TestDocumentCleaner:
         assert cleaner.remove_substrings is None
         assert cleaner.remove_regex is None
 
-    @pytest.mark.unit
-    def test_to_dict(self):
-        cleaner = DocumentCleaner()
-        data = cleaner.to_dict()
-        assert data == {
-            "type": "DocumentCleaner",
-            "init_parameters": {
-                "remove_empty_lines": True,
-                "remove_extra_whitespaces": True,
-                "remove_repeated_substrings": False,
-                "remove_substrings": None,
-                "remove_regex": None,
-            },
-        }
-
-    @pytest.mark.unit
-    def test_to_dict_with_custom_init_parameters(self):
-        cleaner = DocumentCleaner(
-            remove_empty_lines=False,
-            remove_extra_whitespaces=False,
-            remove_repeated_substrings=True,
-            remove_substrings=["a", "b"],
-            remove_regex=r"\s\s+",
-        )
-        data = cleaner.to_dict()
-        assert data == {
-            "type": "DocumentCleaner",
-            "init_parameters": {
-                "remove_empty_lines": False,
-                "remove_extra_whitespaces": False,
-                "remove_repeated_substrings": True,
-                "remove_substrings": ["a", "b"],
-                "remove_regex": r"\s\s+",
-            },
-        }
-
-    @pytest.mark.unit
-    def test_from_dict(self):
-        data = {
-            "type": "DocumentCleaner",
-            "init_parameters": {
-                "remove_empty_lines": False,
-                "remove_extra_whitespaces": False,
-                "remove_repeated_substrings": True,
-                "remove_substrings": ["a", "b"],
-                "remove_regex": r"\s\s+",
-            },
-        }
-        cleaner = DocumentCleaner.from_dict(data)
-        assert cleaner.remove_empty_lines == False
-        assert cleaner.remove_extra_whitespaces == False
-        assert cleaner.remove_repeated_substrings == True
-        assert cleaner.remove_substrings == ["a", "b"]
-        assert cleaner.remove_regex == r"\s\s+"
-
     @pytest.mark.unit
     def test_non_text_document(self, caplog):
         with caplog.at_level(logging.WARNING):
@@ -118,35 +118,6 @@ class TestTextDocumentSplitter:
         assert result["documents"][0].text == "This is a text with some words. There is a "
         assert result["documents"][1].text == "is a second sentence. And there is a third sentence."
 
-    @pytest.mark.unit
-    def test_to_dict(self):
-        splitter = TextDocumentSplitter()
-        data = splitter.to_dict()
-        assert data == {
-            "type": "TextDocumentSplitter",
-            "init_parameters": {"split_by": "word", "split_length": 200, "split_overlap": 0},
-        }
-
-    @pytest.mark.unit
-    def test_to_dict_with_custom_init_parameters(self):
-        splitter = TextDocumentSplitter(split_by="passage", split_length=100, split_overlap=1)
-        data = splitter.to_dict()
-        assert data == {
-            "type": "TextDocumentSplitter",
-            "init_parameters": {"split_by": "passage", "split_length": 100, "split_overlap": 1},
-        }
-
-    @pytest.mark.unit
-    def test_from_dict(self):
-        data = {
-            "type": "TextDocumentSplitter",
-            "init_parameters": {"split_by": "passage", "split_length": 100, "split_overlap": 1},
-        }
-        splitter = TextDocumentSplitter.from_dict(data)
-        assert splitter.split_by == "passage"
-        assert splitter.split_length == 100
-        assert splitter.split_overlap == 1
-
     @pytest.mark.unit
     def test_source_id_stored_in_metadata(self):
         splitter = TextDocumentSplitter(split_by="word", split_length=10)
@@ -6,18 +6,6 @@ from haystack.preview.components.preprocessors import TextLanguageClassifier
 
 
 class TestTextLanguageClassifier:
-    @pytest.mark.unit
-    def test_to_dict(self):
-        component = TextLanguageClassifier(languages=["en", "de"])
-        data = component.to_dict()
-        assert data == {"type": "TextLanguageClassifier", "init_parameters": {"languages": ["en", "de"]}}
-
-    @pytest.mark.unit
-    def test_from_dict(self):
-        data = {"type": "TextLanguageClassifier", "init_parameters": {"languages": ["en", "de"]}}
-        component = TextLanguageClassifier.from_dict(data)
-        assert component.languages == ["en", "de"]
-
     @pytest.mark.unit
     def test_non_string_input(self):
         with pytest.raises(TypeError, match="TextLanguageClassifier expects a str as input."):
@@ -33,19 +33,6 @@ class TestSimilarityRanker:
             },
         }
 
-    @pytest.mark.integration
-    def test_from_dict(self):
-        data = {
-            "type": "SimilarityRanker",
-            "init_parameters": {
-                "device": "cpu",
-                "top_k": 10,
-                "model_name_or_path": "cross-encoder/ms-marco-MiniLM-L-6-v2",
-            },
-        }
-        component = SimilarityRanker.from_dict(data)
-        assert component.model_name_or_path == "cross-encoder/ms-marco-MiniLM-L-6-v2"
-
     @pytest.mark.integration
     @pytest.mark.parametrize(
         "query,docs_before_texts,expected_first_text",
@@ -110,39 +110,6 @@ def test_to_dict():
     }
 
 
-@pytest.mark.unit
-def test_from_dict():
-    data = {
-        "type": "ExtractiveReader",
-        "init_parameters": {
-            "model_name_or_path": "my-model",
-            "device": "cpu",
-            "token": None,
-            "top_k": 30,
-            "confidence_threshold": 0.5,
-            "max_seq_length": 300,
-            "stride": 100,
-            "max_batch_size": 20,
-            "answers_per_seq": 5,
-            "no_answer": False,
-            "calibration_factor": 0.5,
-        },
-    }
-
-    component = ExtractiveReader.from_dict(data)
-    assert component.model_name_or_path == "my-model"
-    assert component.device == "cpu"
-    assert component.token is None
-    assert component.top_k == 30
-    assert component.confidence_threshold == 0.5
-    assert component.max_seq_length == 300
-    assert component.stride == 100
-    assert component.max_batch_size == 20
-    assert component.answers_per_seq == 5
-    assert component.no_answer is False
-    assert component.calibration_factor == 0.5
-
-
 @pytest.mark.unit
 def test_output(mock_reader: ExtractiveReader):
     answers = mock_reader.run(example_queries[0], example_documents[0], top_k=3)[
@@ -11,24 +11,6 @@ from haystack.preview.dataclasses import ByteStream
     reason="Can't run on Windows Github CI, need access to registry to get mime types",
 )
 class TestFileTypeRouter:
-    @pytest.mark.unit
-    def test_to_dict(self):
-        component = FileTypeRouter(mime_types=["text/plain", "audio/x-wav", "image/jpeg"])
-        data = component.to_dict()
-        assert data == {
-            "type": "FileTypeRouter",
-            "init_parameters": {"mime_types": ["text/plain", "audio/x-wav", "image/jpeg"]},
-        }
-
-    @pytest.mark.unit
-    def test_from_dict(self):
-        data = {
-            "type": "FileTypeRouter",
-            "init_parameters": {"mime_types": ["text/plain", "audio/x-wav", "image/jpeg"]},
-        }
-        component = FileTypeRouter.from_dict(data)
-        assert component.mime_types == ["text/plain", "audio/x-wav", "image/jpeg"]
-
     @pytest.mark.unit
     def test_run(self, preview_samples_path):
         """
@@ -5,24 +5,6 @@ from haystack.preview.components.routers.metadata_router import MetadataRouter
 
 
 class TestMetadataRouter:
-    @pytest.mark.unit
-    def test_to_dict(self):
-        component = MetadataRouter(rules={"edge_1": {"created_at": {"$gte": "2023-01-01", "$lt": "2023-04-01"}}})
-        data = component.to_dict()
-        assert data == {
-            "type": "MetadataRouter",
-            "init_parameters": {"rules": {"edge_1": {"created_at": {"$gte": "2023-01-01", "$lt": "2023-04-01"}}}},
-        }
-
-    @pytest.mark.unit
-    def test_from_dict(self):
-        data = {
-            "type": "MetadataRouter",
-            "init_parameters": {"rules": {"edge_1": {"created_at": {"$gte": "2023-01-01", "$lt": "2023-04-01"}}}},
-        }
-        component = MetadataRouter.from_dict(data)
-        assert component.rules == {"edge_1": {"created_at": {"$gte": "2023-01-01", "$lt": "2023-04-01"}}}
-
     @pytest.mark.unit
     def test_run(self):
         rules = {
@@ -7,24 +7,6 @@ from haystack.preview.components.samplers.top_p import TopPSampler
 
 
 class TestTopPSampler:
-    @pytest.mark.unit
-    def test_to_dict(self):
-        component = TopPSampler()
-        data = component.to_dict()
-        assert data == {"type": "TopPSampler", "init_parameters": {"top_p": 1.0, "score_field": None}}
-
-    @pytest.mark.unit
-    def test_to_dict_with_custom_init_parameters(self):
-        component = TopPSampler(top_p=0.92)
-        data = component.to_dict()
-        assert data == {"type": "TopPSampler", "init_parameters": {"top_p": 0.92, "score_field": None}}
-
-    @pytest.mark.unit
-    def test_from_dict(self):
-        data = {"type": "TopPSampler", "init_parameters": {"top_p": 0.9, "score_field": None}}
-        component = TopPSampler.from_dict(data)
-        assert component.top_p == 0.9
-
     @pytest.mark.unit
     def test_run_scores_from_metadata(self):
        """
@@ -124,23 +124,6 @@ class TestSerperDevSearchAPI:
             },
         }
 
-    @pytest.mark.unit
-    def test_from_dict(self):
-        data = {
-            "type": "SerperDevWebSearch",
-            "init_parameters": {
-                "api_key": "test_key",
-                "top_k": 10,
-                "allowed_domains": ["test.com"],
-                "search_params": {"param": "test"},
-            },
-        }
-        component = SerperDevWebSearch.from_dict(data)
-        assert component.api_key == "test_key"
-        assert component.top_k == 10
-        assert component.allowed_domains == ["test.com"]
-        assert component.search_params == {"param": "test"}
-
     @pytest.mark.unit
     @pytest.mark.parametrize("top_k", [1, 5, 7])
     def test_web_search_top_k(self, mock_serper_dev_search_result, top_k: int):