feat!: Use Secret for passing authentication secrets to components (#6887)

* feat!: Use `Secret` for passing authentication secrets to components

* Add comment to clarify type ignore
Madeesh Kannan 2024-02-05 13:17:01 +01:00 committed by GitHub
parent 393a7993c3
commit 27d1af3068
52 changed files with 707 additions and 421 deletions
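
Taken together, the changes below follow one pattern: constructors take a `haystack.utils.Secret` instead of a plain string, resolve it only when building the underlying client, serialize it via `to_dict()`, and rebuild it in `from_dict()` with `deserialize_secrets_inplace`. A minimal sketch of the `Secret` API as this diff uses it (the behavior of `strict` and of token serialization is assumed from Haystack 2.x, not spelled out in the diff itself):

```python
from haystack.utils import Secret

# Environment-based secrets store only the variable name; resolution is lazy.
required = Secret.from_env_var("OPENAI_API_KEY")              # strict=True by default
optional = Secret.from_env_var("HF_API_TOKEN", strict=False)

# Components call resolve_value() when constructing their backing client:
# a strict secret raises if the variable is unset, a non-strict one yields None.
token_or_none = optional.resolve_value()

# Env-var secrets serialize to a description of where to find the value...
print(required.to_dict())  # e.g. {"type": "env_var", "env_vars": [...], "strict": True}

# ...while token-based secrets hold the raw value and refuse serialization,
# so an inline key can never leak into a pipeline file.
inline = Secret.from_token("<your-api-key>")
```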

View File

@ -7,6 +7,7 @@ from openai import OpenAI
from haystack import Document, component, default_from_dict, default_to_dict
from haystack.dataclasses import ByteStream
from haystack.utils import Secret, deserialize_secrets_inplace
logger = logging.getLogger(__name__)
@ -24,7 +25,7 @@ class RemoteWhisperTranscriber:
def __init__(
self,
api_key: Optional[str] = None,
api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
model: str = "whisper-1",
api_base_url: Optional[str] = None,
organization: Optional[str] = None,
@ -61,6 +62,7 @@ class RemoteWhisperTranscriber:
self.organization = organization
self.model = model
self.api_base_url = api_base_url
self.api_key = api_key
# Only response_format = "json" is supported
whisper_params = kwargs
@ -71,7 +73,7 @@ class RemoteWhisperTranscriber:
)
whisper_params["response_format"] = "json"
self.whisper_params = whisper_params
self.client = OpenAI(api_key=api_key, organization=organization, base_url=api_base_url)
self.client = OpenAI(api_key=api_key.resolve_value(), organization=organization, base_url=api_base_url)
def to_dict(self) -> Dict[str, Any]:
"""
@ -81,6 +83,7 @@ class RemoteWhisperTranscriber:
"""
return default_to_dict(
self,
api_key=self.api_key.to_dict(),
model=self.model,
organization=self.organization,
api_base_url=self.api_base_url,
@ -92,6 +95,7 @@ class RemoteWhisperTranscriber:
"""
Deserialize this component from a dictionary.
"""
deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"])
return default_from_dict(cls, data)
@component.output_types(documents=List[Document])
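
The `to_dict`/`from_dict` pair above is the template every component in this commit follows. A hedged round-trip sketch (import path and serialized shape assumed from the diff and Haystack 2.x):

```python
import os
from haystack.components.audio import RemoteWhisperTranscriber

os.environ["OPENAI_API_KEY"] = "test-key"  # placeholder, for illustration only

transcriber = RemoteWhisperTranscriber()  # default: Secret.from_env_var("OPENAI_API_KEY")
data = transcriber.to_dict()
# data["init_parameters"]["api_key"] describes the env var, never the raw key
restored = RemoteWhisperTranscriber.from_dict(data)
assert restored.api_key.resolve_value() == "test-key"
```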

View File

@ -29,10 +29,11 @@ class DynamicChatPromptBuilder:
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.dataclasses import ChatMessage
from haystack import Pipeline
from haystack.utils import Secret
# no parameter init, we don't use any runtime template variables
prompt_builder = DynamicChatPromptBuilder()
llm = OpenAIChatGenerator(api_key="<your-api-key>", model="gpt-3.5-turbo")
llm = OpenAIChatGenerator(api_key=Secret.from_token("<your-api-key>"), model="gpt-3.5-turbo")
pipe = Pipeline()
pipe.add_component("prompt_builder", prompt_builder)
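
One caveat about the `Secret.from_token(...)` form used in docstring examples like the one above: token-based secrets intentionally cannot be serialized, so pipelines destined for YAML should use env-var secrets instead. A sketch under that assumption:

```python
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.utils import Secret

# Serializable: only the env var name is stored (requires OPENAI_API_KEY at init).
llm = OpenAIChatGenerator(api_key=Secret.from_env_var("OPENAI_API_KEY"))
llm.to_dict()

# Fine for quick scripts, but to_dict() on this component would raise,
# precisely so the raw key cannot end up in a pipeline file.
inline_llm = OpenAIChatGenerator(api_key=Secret.from_token("<your-api-key>"))
```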

View File

@ -21,9 +21,10 @@ class DynamicPromptBuilder:
from haystack.components.builders import DynamicPromptBuilder
from haystack.components.generators import OpenAIGenerator
from haystack import Pipeline, component, Document
from haystack.utils import Secret
prompt_builder = DynamicPromptBuilder(runtime_variables=["documents"])
llm = OpenAIGenerator(api_key="<your-api-key>", model="gpt-3.5-turbo")
llm = OpenAIGenerator(api_key=Secret.from_token("<your-api-key>"), model="gpt-3.5-turbo")
@component

View File

@ -1,12 +1,12 @@
from pathlib import Path
from typing import List, Union, Dict, Any, Optional
import logging
import os
from haystack.lazy_imports import LazyImport
from haystack import component, Document, default_to_dict
from haystack import component, Document, default_to_dict, default_from_dict
from haystack.dataclasses import ByteStream
from haystack.components.converters.utils import get_bytestream_from_source, normalize_metadata
from haystack.utils import Secret, deserialize_secrets_inplace
logger = logging.getLogger(__name__)
@ -29,8 +29,9 @@ class AzureOCRDocumentConverter:
Usage example:
```python
from haystack.components.converters.azure import AzureOCRDocumentConverter
from haystack.utils import Secret
converter = AzureOCRDocumentConverter()
converter = AzureOCRDocumentConverter(endpoint="<url>", api_key=Secret.from_token("<your-api-key>"))
results = converter.run(sources=["image-based-document.pdf"], meta={"date_added": datetime.now().isoformat()})
documents = results["documents"]
print(documents[0].content)
@ -38,33 +39,23 @@ class AzureOCRDocumentConverter:
```
"""
def __init__(self, endpoint: str, api_key: Optional[str] = None, model_id: str = "prebuilt-read"):
def __init__(
self, endpoint: str, api_key: Secret = Secret.from_env_var("AZURE_AI_API_KEY"), model_id: str = "prebuilt-read"
):
"""
Create an AzureOCRDocumentConverter component.
:param endpoint: The endpoint of your Azure resource.
:param api_key: The key of your Azure resource. It can be
explicitly provided or automatically read from the
environment variable AZURE_AI_API_KEY (recommended).
:param api_key: The key of your Azure resource.
:param model_id: The model ID of the model you want to use. Please refer to [Azure documentation](https://learn.microsoft.com/en-us/azure/ai-services/document-intelligence/choose-model-feature)
for a list of available models. Default: `"prebuilt-read"`.
"""
azure_import.check()
api_key = api_key or os.environ.get("AZURE_AI_API_KEY")
# we check whether api_key is None or an empty string
if not api_key:
msg = (
"AzureOCRDocumentConverter expects an API key. "
"Set the AZURE_AI_API_KEY environment variable (recommended) or pass it explicitly."
)
raise ValueError(msg)
self.document_analysis_client = DocumentAnalysisClient(
endpoint=endpoint, credential=AzureKeyCredential(api_key)
)
self.document_analysis_client = DocumentAnalysisClient(endpoint=endpoint, credential=AzureKeyCredential(api_key.resolve_value())) # type: ignore
self.endpoint = endpoint
self.model_id = model_id
self.api_key = api_key
@component.output_types(documents=List[Document], raw_azure_response=List[Dict])
def run(self, sources: List[Union[str, Path, ByteStream]], meta: Optional[List[Dict[str, Any]]] = None):
@ -116,7 +107,15 @@ class AzureOCRDocumentConverter:
"""
Serialize this component to a dictionary.
"""
return default_to_dict(self, endpoint=self.endpoint, model_id=self.model_id)
return default_to_dict(self, api_key=self.api_key.to_dict(), endpoint=self.endpoint, model_id=self.model_id)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "AzureOCRDocumentConverter":
"""
Deserialize this component from a dictionary.
"""
deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"])
return default_from_dict(cls, data)
@staticmethod
def _convert_azure_result_to_document(result: "AnalyzeResult", file_suffix: Optional[str] = None) -> Document:
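
For callers, the change in this file is breaking: the implicit `os.environ` fallback and the manual emptiness check are gone, replaced by the `Secret` default. A migration sketch (assuming the `AZURE_AI_API_KEY` variable named in the diff):

```python
from haystack.components.converters.azure import AzureOCRDocumentConverter
from haystack.utils import Secret

# Before #6887: api_key was an Optional[str] with a hidden env-var fallback.
# converter = AzureOCRDocumentConverter(endpoint="<url>", api_key="<your-api-key>")

# After: the fallback is explicit in the default (requires AZURE_AI_API_KEY set)...
converter = AzureOCRDocumentConverter(endpoint="<url>")

# ...or pass an inline, non-serializable secret.
converter = AzureOCRDocumentConverter(endpoint="<url>", api_key=Secret.from_token("<your-api-key>"))
```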

View File

@ -1,10 +1,11 @@
import os
from typing import List, Optional, Dict, Any, Tuple
from openai.lib.azure import AzureADTokenProvider, AzureOpenAI
from openai.lib.azure import AzureOpenAI
from tqdm import tqdm
from haystack import component, Document, default_to_dict
from haystack import component, Document, default_to_dict, default_from_dict
from haystack.utils import Secret, deserialize_secrets_inplace
@component
@ -34,9 +35,8 @@ class AzureOpenAIDocumentEmbedder:
azure_endpoint: Optional[str] = None,
api_version: Optional[str] = "2023-05-15",
azure_deployment: str = "text-embedding-ada-002",
api_key: Optional[str] = None,
azure_ad_token: Optional[str] = None,
azure_ad_token_provider: Optional[AzureADTokenProvider] = None,
api_key: Optional[Secret] = Secret.from_env_var("AZURE_OPENAI_API_KEY", strict=False),
azure_ad_token: Optional[Secret] = Secret.from_env_var("AZURE_OPENAI_AD_TOKEN", strict=False),
organization: Optional[str] = None,
prefix: str = "",
suffix: str = "",
@ -53,8 +53,6 @@ class AzureOpenAIDocumentEmbedder:
:param azure_deployment: The deployment of the model, usually the model name.
:param api_key: The API key to use for authentication.
:param azure_ad_token: Azure Active Directory token, see https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id
:param azure_ad_token_provider: A function that returns an Azure Active Directory token, will be invoked
on every request.
:param organization: The Organization ID, defaults to `None`. See
[production best practices](https://platform.openai.com/docs/guides/production-best-practices/setting-up-your-organization).
:param prefix: A string to add to the beginning of each text.
@ -70,6 +68,11 @@ class AzureOpenAIDocumentEmbedder:
if not azure_endpoint:
raise ValueError("Please provide an Azure endpoint or set the environment variable AZURE_OPENAI_ENDPOINT.")
if api_key is None and azure_ad_token is None:
raise ValueError("Please provide an API key or an Azure Active Directory token.")
self.api_key = api_key
self.azure_ad_token = azure_ad_token
self.api_version = api_version
self.azure_endpoint = azure_endpoint
self.azure_deployment = azure_deployment
@ -85,9 +88,8 @@ class AzureOpenAIDocumentEmbedder:
api_version=api_version,
azure_endpoint=azure_endpoint,
azure_deployment=azure_deployment,
api_key=api_key,
azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
api_key=api_key.resolve_value() if api_key is not None else None,
azure_ad_token=azure_ad_token.resolve_value() if azure_ad_token is not None else None,
organization=organization,
)
@ -98,10 +100,6 @@ class AzureOpenAIDocumentEmbedder:
return {"model": self.azure_deployment}
def to_dict(self) -> Dict[str, Any]:
"""
This method overrides the default serializer in order to avoid leaking the `api_key` value passed
to the constructor.
"""
return default_to_dict(
self,
azure_endpoint=self.azure_endpoint,
@ -114,8 +112,15 @@ class AzureOpenAIDocumentEmbedder:
progress_bar=self.progress_bar,
meta_fields_to_embed=self.meta_fields_to_embed,
embedding_separator=self.embedding_separator,
api_key=self.api_key.to_dict() if self.api_key is not None else None,
azure_ad_token=self.azure_ad_token.to_dict() if self.azure_ad_token is not None else None,
)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "AzureOpenAIDocumentEmbedder":
deserialize_secrets_inplace(data["init_parameters"], keys=["api_key", "azure_ad_token"])
return default_from_dict(cls, data)
def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
"""
Prepare the texts to embed by concatenating the Document text with the metadata fields to embed.
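
The Azure components accept two alternative credentials, so both default to non-strict env-var secrets and only the explicit `None`/`None` combination is rejected. A standalone sketch of that resolution logic (mirroring the diff, not Haystack's actual helper):

```python
from typing import Optional, Tuple
from haystack.utils import Secret

def resolve_azure_credentials(
    api_key: Optional[Secret] = Secret.from_env_var("AZURE_OPENAI_API_KEY", strict=False),
    azure_ad_token: Optional[Secret] = Secret.from_env_var("AZURE_OPENAI_AD_TOKEN", strict=False),
) -> Tuple[Optional[str], Optional[str]]:
    # Only triggers when the caller explicitly passes None for both.
    if api_key is None and azure_ad_token is None:
        raise ValueError("Please provide an API key or an Azure Active Directory token.")
    # Exactly what the diff hands to the AzureOpenAI client:
    return (
        api_key.resolve_value() if api_key is not None else None,
        azure_ad_token.resolve_value() if azure_ad_token is not None else None,
    )
```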

View File

@ -1,9 +1,10 @@
import os
from typing import List, Optional, Dict, Any
from openai.lib.azure import AzureADTokenProvider, AzureOpenAI
from openai.lib.azure import AzureOpenAI
from haystack import component, default_to_dict, Document
from haystack import component, Document, default_to_dict, default_from_dict
from haystack.utils import Secret, deserialize_secrets_inplace
@component
@ -32,9 +33,8 @@ class AzureOpenAITextEmbedder:
azure_endpoint: Optional[str] = None,
api_version: Optional[str] = "2023-05-15",
azure_deployment: str = "text-embedding-ada-002",
api_key: Optional[str] = None,
azure_ad_token: Optional[str] = None,
azure_ad_token_provider: Optional[AzureADTokenProvider] = None,
api_key: Optional[Secret] = Secret.from_env_var("AZURE_OPENAI_API_KEY", strict=False),
azure_ad_token: Optional[Secret] = Secret.from_env_var("AZURE_OPENAI_AD_TOKEN", strict=False),
organization: Optional[str] = None,
prefix: str = "",
suffix: str = "",
@ -47,8 +47,6 @@ class AzureOpenAITextEmbedder:
:param azure_deployment: The deployment of the model, usually the model name.
:param api_key: The API key to use for authentication.
:param azure_ad_token: Azure Active Directory token, see https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id
:param azure_ad_token_provider: A function that returns an Azure Active Directory token, will be invoked
on every request.
:param organization: The Organization ID, defaults to `None`. See
[production best practices](https://platform.openai.com/docs/guides/production-best-practices/setting-up-your-organization).
:param prefix: A string to add to the beginning of each text.
@ -62,6 +60,11 @@ class AzureOpenAITextEmbedder:
if not azure_endpoint:
raise ValueError("Please provide an Azure endpoint or set the environment variable AZURE_OPENAI_ENDPOINT.")
if api_key is None and azure_ad_token is None:
raise ValueError("Please provide an API key or an Azure Active Directory token.")
self.api_key = api_key
self.azure_ad_token = azure_ad_token
self.api_version = api_version
self.azure_endpoint = azure_endpoint
self.azure_deployment = azure_deployment
@ -73,9 +76,8 @@ class AzureOpenAITextEmbedder:
api_version=api_version,
azure_endpoint=azure_endpoint,
azure_deployment=azure_deployment,
api_key=api_key,
azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
api_key=api_key.resolve_value() if api_key is not None else None,
azure_ad_token=azure_ad_token.resolve_value() if azure_ad_token is not None else None,
organization=organization,
)
@ -98,8 +100,15 @@ class AzureOpenAITextEmbedder:
api_version=self.api_version,
prefix=self.prefix,
suffix=self.suffix,
api_key=self.api_key.to_dict() if self.api_key is not None else None,
azure_ad_token=self.azure_ad_token.to_dict() if self.azure_ad_token is not None else None,
)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "AzureOpenAITextEmbedder":
deserialize_secrets_inplace(data["init_parameters"], keys=["api_key", "azure_ad_token"])
return default_from_dict(cls, data)
@component.output_types(embedding=List[float], meta=Dict[str, Any])
def run(self, text: str):
"""Embed a string using AzureOpenAITextEmbedder."""

View File

@ -1,6 +1,7 @@
from typing import List, Optional, Union, Dict
from typing import List, Optional, Dict
from haystack.lazy_imports import LazyImport
from haystack.utils.auth import Secret
with LazyImport(message="Run 'pip install \"sentence-transformers>=2.2.0\"'") as sentence_transformers_import:
from sentence_transformers import SentenceTransformer
@ -14,14 +15,12 @@ class _SentenceTransformersEmbeddingBackendFactory:
_instances: Dict[str, "_SentenceTransformersEmbeddingBackend"] = {}
@staticmethod
def get_embedding_backend(model: str, device: Optional[str] = None, use_auth_token: Union[bool, str, None] = None):
embedding_backend_id = f"{model}{device}{use_auth_token}"
def get_embedding_backend(model: str, device: Optional[str] = None, auth_token: Optional[Secret] = None):
embedding_backend_id = f"{model}{device}{auth_token}"
if embedding_backend_id in _SentenceTransformersEmbeddingBackendFactory._instances:
return _SentenceTransformersEmbeddingBackendFactory._instances[embedding_backend_id]
embedding_backend = _SentenceTransformersEmbeddingBackend(
model=model, device=device, use_auth_token=use_auth_token
)
embedding_backend = _SentenceTransformersEmbeddingBackend(model=model, device=device, auth_token=auth_token)
_SentenceTransformersEmbeddingBackendFactory._instances[embedding_backend_id] = embedding_backend
return embedding_backend
@ -31,9 +30,11 @@ class _SentenceTransformersEmbeddingBackend:
Class to manage Sentence Transformers embeddings.
"""
def __init__(self, model: str, device: Optional[str] = None, use_auth_token: Union[bool, str, None] = None):
def __init__(self, model: str, device: Optional[str] = None, auth_token: Optional[Secret] = None):
sentence_transformers_import.check()
self.model = SentenceTransformer(model_name_or_path=model, device=device, use_auth_token=use_auth_token)
self.model = SentenceTransformer(
model_name_or_path=model, device=device, use_auth_token=auth_token.resolve_value() if auth_token else None
)
def embed(self, data: List[str], **kwargs) -> List[List[float]]:
embeddings = self.model.encode(data, **kwargs).tolist()

View File

@ -1,26 +1,27 @@
from typing import Optional
from haystack.lazy_imports import LazyImport
from haystack.utils.auth import Secret
with LazyImport(message="Run 'pip install transformers'") as transformers_import:
from huggingface_hub import HfApi
from huggingface_hub.utils import RepositoryNotFoundError
def check_valid_model(model_id: str, token: Optional[str]) -> None:
def check_valid_model(model_id: str, token: Optional[Secret]) -> None:
"""
Check if the provided model ID corresponds to a valid model on HuggingFace Hub.
Also check if the model is an embedding model.
:param model_id: A string representing the HuggingFace model ID.
:param token: An optional string representing the authentication token.
:param token: The optional authentication token.
:raises ValueError: If the model is not found or is not an embedding model.
"""
transformers_import.check()
api = HfApi()
try:
model_info = api.model_info(model_id, token=token)
model_info = api.model_info(model_id, token=token.resolve_value() if token else None)
except RepositoryNotFoundError as e:
raise ValueError(
f"Model {model_id} not found on HuggingFace Hub. Please provide a valid HuggingFace model_id."

View File

@ -1,14 +1,14 @@
import logging
import os
from typing import Any, Dict, List, Optional
from urllib.parse import urlparse
from tqdm import tqdm
from haystack import component, default_to_dict
from haystack.components.embedders.hf_utils import check_valid_model
from haystack.dataclasses import Document
from haystack.lazy_imports import LazyImport
from haystack.utils import Secret, deserialize_secrets_inplace
from haystack import component, default_to_dict, default_from_dict
with LazyImport(message="Run 'pip install transformers'") as transformers_import:
from huggingface_hub import InferenceClient
@ -29,11 +29,12 @@ class HuggingFaceTEIDocumentEmbedder:
```python
from haystack.dataclasses import Document
from haystack.components.embedders import HuggingFaceTEIDocumentEmbedder
from haystack.utils import Secret
doc = Document(content="I love pizza!")
document_embedder = HuggingFaceTEIDocumentEmbedder(
model="BAAI/bge-small-en-v1.5", token="<your-token>"
model="BAAI/bge-small-en-v1.5", token=Secret.from_token("<your-api-key>")
)
result = document_embedder.run([doc])
@ -52,7 +53,7 @@ class HuggingFaceTEIDocumentEmbedder:
doc = Document(content="I love pizza!")
document_embedder = HuggingFaceTEIDocumentEmbedder(
model="BAAI/bge-small-en-v1.5", url="<your-tei-endpoint-url>", token="<your-token>"
model="BAAI/bge-small-en-v1.5", url="<your-tei-endpoint-url>", token=Secret.from_token("<your-api-key>")
)
result = document_embedder.run([doc])
@ -83,7 +84,7 @@ class HuggingFaceTEIDocumentEmbedder:
self,
model: str = "BAAI/bge-small-en-v1.5",
url: Optional[str] = None,
token: Optional[str] = None,
token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
prefix: str = "",
suffix: str = "",
batch_size: int = 32,
@ -98,8 +99,7 @@ class HuggingFaceTEIDocumentEmbedder:
:param url: The URL of your self-deployed Text-Embeddings-Inference service or the URL of your paid HF Inference
Endpoint.
:param token: The HuggingFace Hub token. This is needed if you are using a paid HF Inference Endpoint or serving
a private or gated model. It can be explicitly provided or automatically read from the environment
variable HF_API_TOKEN (recommended).
a private or gated model.
:param prefix: A string to add to the beginning of each text.
:param suffix: A string to add to the end of each text.
:param batch_size: Number of Documents to encode at once.
@ -116,15 +116,12 @@ class HuggingFaceTEIDocumentEmbedder:
if not is_valid_url:
raise ValueError(f"Invalid TEI endpoint URL provided: {url}")
# The user does not need to provide a token if it is a local server or free public HF Inference Endpoint.
token = token or os.environ.get("HF_API_TOKEN")
check_valid_model(model, token)
self.model = model
self.url = url
self.token = token
self.client = InferenceClient(url or model, token=token)
self.client = InferenceClient(url or model, token=token.resolve_value() if token else None)
self.prefix = prefix
self.suffix = suffix
self.batch_size = batch_size
@ -133,10 +130,6 @@ class HuggingFaceTEIDocumentEmbedder:
self.embedding_separator = embedding_separator
def to_dict(self) -> Dict[str, Any]:
"""
This method overrides the default serializer in order to avoid leaking the `token` value passed
to the constructor.
"""
return default_to_dict(
self,
model=self.model,
@ -147,8 +140,14 @@ class HuggingFaceTEIDocumentEmbedder:
progress_bar=self.progress_bar,
meta_fields_to_embed=self.meta_fields_to_embed,
embedding_separator=self.embedding_separator,
token=self.token.to_dict() if self.token else None,
)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceTEIDocumentEmbedder":
deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
return default_from_dict(cls, data)
def _get_telemetry_data(self) -> Dict[str, Any]:
"""
Data that is sent to Posthog for usage analytics.
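
All Hugging Face components in this diff default to `Secret.from_env_var("HF_API_TOKEN", strict=False)` because a token is only needed for paid endpoints and gated models. A sketch of what `strict` changes (behavior assumed from how the diff relies on it):

```python
import os
from haystack.utils import Secret

os.environ.pop("HF_API_TOKEN", None)  # simulate an unset variable

optional_token = Secret.from_env_var("HF_API_TOKEN", strict=False)
print(optional_token.resolve_value())  # None: public models work without a token

required_key = Secret.from_env_var("OPENAI_API_KEY")  # strict=True by default
# required_key.resolve_value() raises a ValueError while the variable is unset
```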

View File

@ -1,11 +1,11 @@
import logging
import os
from typing import Any, Dict, List, Optional
from urllib.parse import urlparse
from haystack import component, default_to_dict
from haystack import component, default_to_dict, default_from_dict
from haystack.components.embedders.hf_utils import check_valid_model
from haystack.lazy_imports import LazyImport
from haystack.utils import Secret, deserialize_secrets_inplace
with LazyImport(message="Run 'pip install transformers'") as transformers_import:
from huggingface_hub import InferenceClient
@ -23,11 +23,12 @@ class HuggingFaceTEITextEmbedder:
Inference API tier:
```python
from haystack.components.embedders import HuggingFaceTEITextEmbedder
from haystack.utils import Secret
text_to_embed = "I love pizza!"
text_embedder = HuggingFaceTEITextEmbedder(
model="BAAI/bge-small-en-v1.5", token="<your-token>"
model="BAAI/bge-small-en-v1.5", token=Secret.from_token("<your-api-key>")
)
print(text_embedder.run(text_to_embed))
@ -44,7 +45,7 @@ class HuggingFaceTEITextEmbedder:
text_to_embed = "I love pizza!"
text_embedder = HuggingFaceTEITextEmbedder(
model="BAAI/bge-small-en-v1.5", url="<your-tei-endpoint-url>", token="<your-token>"
model="BAAI/bge-small-en-v1.5", url="<your-tei-endpoint-url>", token=Secret.from_token("<your-api-key>")
)
print(text_embedder.run(text_to_embed))
@ -74,7 +75,7 @@ class HuggingFaceTEITextEmbedder:
self,
model: str = "BAAI/bge-small-en-v1.5",
url: Optional[str] = None,
token: Optional[str] = None,
token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
prefix: str = "",
suffix: str = "",
):
@ -85,8 +86,7 @@ class HuggingFaceTEITextEmbedder:
:param url: The URL of your self-deployed Text-Embeddings-Inference service or the URL of your paid HF Inference
Endpoint.
:param token: The HuggingFace Hub token. This is needed if you are using a paid HF Inference Endpoint or serving
a private or gated model. It can be explicitly provided or automatically read from the environment
variable HF_API_TOKEN (recommended).
a private or gated model.
:param prefix: A string to add to the beginning of each text.
:param suffix: A string to add to the end of each text.
"""
@ -98,24 +98,29 @@ class HuggingFaceTEITextEmbedder:
if not is_valid_url:
raise ValueError(f"Invalid TEI endpoint URL provided: {url}")
# The user does not need to provide a token if it is a local server or free public HF Inference Endpoint.
token = token or os.environ.get("HF_API_TOKEN")
check_valid_model(model, token)
self.model = model
self.url = url
self.token = token
self.client = InferenceClient(url or model, token=token)
self.client = InferenceClient(url or model, token=token.resolve_value() if token else None)
self.prefix = prefix
self.suffix = suffix
def to_dict(self) -> Dict[str, Any]:
"""
This method overrides the default serializer in order to avoid leaking the `token` value passed
to the constructor.
"""
return default_to_dict(self, model=self.model, url=self.url, prefix=self.prefix, suffix=self.suffix)
return default_to_dict(
self,
model=self.model,
url=self.url,
prefix=self.prefix,
suffix=self.suffix,
token=self.token.to_dict() if self.token else None,
)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceTEITextEmbedder":
deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
return default_from_dict(cls, data)
def _get_telemetry_data(self) -> Dict[str, Any]:
"""

View File

@ -3,7 +3,8 @@ from typing import List, Optional, Dict, Any, Tuple
from openai import OpenAI
from tqdm import tqdm
from haystack import component, Document, default_to_dict
from haystack import component, Document, default_to_dict, default_from_dict
from haystack.utils import Secret, deserialize_secrets_inplace
@component
@ -30,7 +31,7 @@ class OpenAIDocumentEmbedder:
def __init__(
self,
api_key: Optional[str] = None,
api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
model: str = "text-embedding-ada-002",
api_base_url: Optional[str] = None,
organization: Optional[str] = None,
@ -43,8 +44,7 @@ class OpenAIDocumentEmbedder:
):
"""
Create an OpenAIDocumentEmbedder component.
:param api_key: The OpenAI API key. It can be explicitly provided or automatically read from the
environment variable OPENAI_API_KEY (recommended).
:param api_key: The OpenAI API key.
:param model: The name of the model to use.
:param api_base_url: The OpenAI API Base url, defaults to None. For more details, see OpenAI [docs](https://platform.openai.com/docs/api-reference/audio).
:param organization: The Organization ID, defaults to `None`. See
@ -57,6 +57,7 @@ class OpenAIDocumentEmbedder:
:param meta_fields_to_embed: List of meta fields that should be embedded along with the Document text.
:param embedding_separator: Separator used to concatenate the meta fields to the Document text.
"""
self.api_key = api_key
self.model = model
self.api_base_url = api_base_url
self.organization = organization
@ -67,7 +68,7 @@ class OpenAIDocumentEmbedder:
self.meta_fields_to_embed = meta_fields_to_embed or []
self.embedding_separator = embedding_separator
self.client = OpenAI(api_key=api_key, organization=organization, base_url=api_base_url)
self.client = OpenAI(api_key=api_key.resolve_value(), organization=organization, base_url=api_base_url)
def _get_telemetry_data(self) -> Dict[str, Any]:
"""
@ -91,8 +92,14 @@ class OpenAIDocumentEmbedder:
progress_bar=self.progress_bar,
meta_fields_to_embed=self.meta_fields_to_embed,
embedding_separator=self.embedding_separator,
api_key=self.api_key.to_dict(),
)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "OpenAIDocumentEmbedder":
deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"])
return default_from_dict(cls, data)
def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
"""
Prepare the texts to embed by concatenating the Document text with the metadata fields to embed.
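
With `OPENAI_API_KEY` now baked into the default, typical usage needs no key handling at all. A minimal sketch (assumes the variable is set and network access to OpenAI):

```python
from haystack import Document
from haystack.components.embedders import OpenAIDocumentEmbedder

embedder = OpenAIDocumentEmbedder()  # key resolved from OPENAI_API_KEY at init
result = embedder.run([Document(content="I love pizza!")])
print(len(result["documents"][0].embedding))
```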

View File

@ -2,7 +2,8 @@ from typing import List, Optional, Dict, Any
from openai import OpenAI
from haystack import component, default_to_dict
from haystack import component, default_to_dict, default_from_dict
from haystack.utils import Secret, deserialize_secrets_inplace
@component
@ -28,7 +29,7 @@ class OpenAITextEmbedder:
def __init__(
self,
api_key: Optional[str] = None,
api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
model: str = "text-embedding-ada-002",
api_base_url: Optional[str] = None,
organization: Optional[str] = None,
@ -38,8 +39,7 @@ class OpenAITextEmbedder:
"""
Create an OpenAITextEmbedder component.
:param api_key: The OpenAI API key. It can be explicitly provided or automatically read from the
environment variable OPENAI_API_KEY (recommended).
:param api_key: The OpenAI API key.
:param model: The name of the OpenAI model to use. For more details on the available models,
see [OpenAI documentation](https://platform.openai.com/docs/guides/embeddings/embedding-models).
:param organization: The Organization ID, defaults to `None`. See
@ -52,8 +52,9 @@ class OpenAITextEmbedder:
self.organization = organization
self.prefix = prefix
self.suffix = suffix
self.api_key = api_key
self.client = OpenAI(api_key=api_key, organization=organization, base_url=api_base_url)
self.client = OpenAI(api_key=api_key.resolve_value(), organization=organization, base_url=api_base_url)
def _get_telemetry_data(self) -> Dict[str, Any]:
"""
@ -62,15 +63,20 @@ class OpenAITextEmbedder:
return {"model": self.model}
def to_dict(self) -> Dict[str, Any]:
"""
This method overrides the default serializer in order to avoid leaking the `api_key` value passed
to the constructor.
"""
return default_to_dict(
self, model=self.model, organization=self.organization, prefix=self.prefix, suffix=self.suffix
self,
model=self.model,
organization=self.organization,
prefix=self.prefix,
suffix=self.suffix,
api_key=self.api_key.to_dict(),
)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "OpenAITextEmbedder":
deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"])
return default_from_dict(cls, data)
@component.output_types(embedding=List[float], meta=Dict[str, Any])
def run(self, text: str):
"""Embed a string."""

View File

@ -1,9 +1,10 @@
from typing import List, Optional, Union, Dict, Any
from typing import List, Optional, Dict, Any
from haystack import component, Document, default_to_dict
from haystack import component, Document, default_to_dict, default_from_dict
from haystack.components.embedders.backends.sentence_transformers_backend import (
_SentenceTransformersEmbeddingBackendFactory,
)
from haystack.utils import Secret, deserialize_secrets_inplace
@component
@ -31,7 +32,7 @@ class SentenceTransformersDocumentEmbedder:
self,
model: str = "sentence-transformers/all-mpnet-base-v2",
device: Optional[str] = None,
token: Union[bool, str, None] = None,
token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
prefix: str = "",
suffix: str = "",
batch_size: int = 32,
@ -48,8 +49,6 @@ class SentenceTransformersDocumentEmbedder:
:param device: Device (like 'cuda' / 'cpu') that should be used for computation.
Defaults to CPU.
:param token: The API token used to download private models from Hugging Face.
If this parameter is set to `True`, then the token generated when running
`transformers-cli login` (stored in ~/.huggingface) will be used.
:param prefix: A string to add to the beginning of each Document text before embedding.
Can be used to prepend the text with an instruction, as required by some embedding models,
such as E5 and bge.
@ -87,7 +86,7 @@ class SentenceTransformersDocumentEmbedder:
self,
model=self.model,
device=self.device,
token=self.token if not isinstance(self.token, str) else None, # don't serialize valid tokens
token=self.token.to_dict() if self.token else None,
prefix=self.prefix,
suffix=self.suffix,
batch_size=self.batch_size,
@ -97,13 +96,18 @@ class SentenceTransformersDocumentEmbedder:
embedding_separator=self.embedding_separator,
)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "SentenceTransformersDocumentEmbedder":
deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
return default_from_dict(cls, data)
def warm_up(self):
"""
Load the embedding backend.
"""
if not hasattr(self, "embedding_backend"):
self.embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(
model=self.model, device=self.device, use_auth_token=self.token
model=self.model, device=self.device, auth_token=self.token
)
@component.output_types(documents=List[Document])

View File

@ -1,9 +1,10 @@
from typing import List, Optional, Union, Dict, Any
from typing import List, Optional, Dict, Any
from haystack import component, default_to_dict
from haystack import component, default_to_dict, default_from_dict
from haystack.components.embedders.backends.sentence_transformers_backend import (
_SentenceTransformersEmbeddingBackendFactory,
)
from haystack.utils import Secret, deserialize_secrets_inplace
@component
@ -30,7 +31,7 @@ class SentenceTransformersTextEmbedder:
self,
model: str = "sentence-transformers/all-mpnet-base-v2",
device: Optional[str] = None,
token: Union[bool, str, None] = None,
token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
prefix: str = "",
suffix: str = "",
batch_size: int = 32,
@ -45,8 +46,6 @@ class SentenceTransformersTextEmbedder:
:param device: Device (like 'cuda' / 'cpu') that should be used for computation.
Defaults to CPU.
:param token: The API token used to download private models from Hugging Face.
If this parameter is set to `True`, then the token generated when running
`transformers-cli login` (stored in ~/.huggingface) will be used.
:param prefix: A string to add to the beginning of each Document text before embedding.
Can be used to prepend the text with an instruction, as required by some embedding models,
such as E5 and bge.
@ -80,7 +79,7 @@ class SentenceTransformersTextEmbedder:
self,
model=self.model,
device=self.device,
token=self.token if not isinstance(self.token, str) else None, # don't serialize valid tokens
token=self.token.to_dict() if self.token else None,
prefix=self.prefix,
suffix=self.suffix,
batch_size=self.batch_size,
@ -88,13 +87,18 @@ class SentenceTransformersTextEmbedder:
normalize_embeddings=self.normalize_embeddings,
)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "SentenceTransformersTextEmbedder":
deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
return default_from_dict(cls, data)
def warm_up(self):
"""
Load the embedding backend.
"""
if not hasattr(self, "embedding_backend"):
self.embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(
model=self.model, device=self.device, use_auth_token=self.token
model=self.model, device=self.device, auth_token=self.token
)
@component.output_types(embedding=List[float])

View File

@ -3,12 +3,13 @@ import os
from typing import Optional, Callable, Dict, Any
# pylint: disable=import-error
from openai.lib.azure import AzureADTokenProvider, AzureOpenAI
from openai.lib.azure import AzureOpenAI
from haystack import default_to_dict, default_from_dict
from haystack.components.generators import OpenAIGenerator
from haystack.components.generators.utils import serialize_callback_handler, deserialize_callback_handler
from haystack.dataclasses import StreamingChunk
from haystack.utils import Secret, deserialize_secrets_inplace
logger = logging.getLogger(__name__)
@ -28,8 +29,9 @@ class AzureOpenAIGenerator(OpenAIGenerator):
```python
from haystack.components.generators import AzureOpenAIGenerator
from haystack.utils import Secret
client = AzureOpenAIGenerator(azure_endpoint="<Your Azure endpoint e.g. `https://your-company.azure.openai.com/>",
api_key="<you api key>",
api_key=Secret.from_token("<your-api-key>"),
azure_deployment="<this a model name, e.g. gpt-35-turbo>")
response = client.run("What's Natural Language Processing? Be brief.")
print(response)
@ -56,9 +58,8 @@ class AzureOpenAIGenerator(OpenAIGenerator):
azure_endpoint: Optional[str] = None,
api_version: Optional[str] = "2023-05-15",
azure_deployment: Optional[str] = "gpt-35-turbo",
api_key: Optional[str] = None,
azure_ad_token: Optional[str] = None,
azure_ad_token_provider: Optional[AzureADTokenProvider] = None,
api_key: Optional[Secret] = Secret.from_env_var("AZURE_OPENAI_API_KEY", strict=False),
azure_ad_token: Optional[Secret] = Secret.from_env_var("AZURE_OPENAI_AD_TOKEN", strict=False),
organization: Optional[str] = None,
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
system_prompt: Optional[str] = None,
@ -70,8 +71,6 @@ class AzureOpenAIGenerator(OpenAIGenerator):
:param azure_deployment: The deployment of the model, usually the model name.
:param api_key: The API key to use for authentication.
:param azure_ad_token: Azure Active Directory token, see https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id
:param azure_ad_token_provider: A function that returns an Azure Active Directory token, will be invoked
on every request.
:param organization: The Organization ID, defaults to `None`. See
[production best practices](https://platform.openai.com/docs/guides/production-best-practices/setting-up-your-organization).
:param streaming_callback: A callback function that is called when a new token is received from the stream.
@ -108,6 +107,13 @@ class AzureOpenAIGenerator(OpenAIGenerator):
if not azure_endpoint:
raise ValueError("Please provide an Azure endpoint or set the environment variable AZURE_OPENAI_ENDPOINT.")
if api_key is None and azure_ad_token is None:
raise ValueError("Please provide an API key or an Azure Active Directory token.")
# The check above makes mypy incorrectly infer that api_key is never None,
# which propagates the incorrect type.
self.api_key = api_key # type: ignore
self.azure_ad_token = azure_ad_token
self.generation_kwargs = generation_kwargs or {}
self.system_prompt = system_prompt
self.streaming_callback = streaming_callback
@ -121,9 +127,8 @@ class AzureOpenAIGenerator(OpenAIGenerator):
api_version=api_version,
azure_endpoint=azure_endpoint,
azure_deployment=azure_deployment,
api_key=api_key,
azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
api_key=api_key.resolve_value() if api_key is not None else None,
azure_ad_token=azure_ad_token.resolve_value() if azure_ad_token is not None else None,
organization=organization,
)
@ -142,6 +147,8 @@ class AzureOpenAIGenerator(OpenAIGenerator):
streaming_callback=callback_name,
generation_kwargs=self.generation_kwargs,
system_prompt=self.system_prompt,
api_key=self.api_key.to_dict() if self.api_key is not None else None,
azure_ad_token=self.azure_ad_token.to_dict() if self.azure_ad_token is not None else None,
)
@classmethod
@ -151,6 +158,7 @@ class AzureOpenAIGenerator(OpenAIGenerator):
:param data: The dictionary representation of this component.
:return: The deserialized component instance.
"""
deserialize_secrets_inplace(data["init_parameters"], keys=["api_key", "azure_ad_token"])
init_params = data.get("init_parameters", {})
serialized_callback_handler = init_params.get("streaming_callback")
if serialized_callback_handler:

View File

@ -3,12 +3,13 @@ import os
from typing import Optional, Callable, Dict, Any
# pylint: disable=import-error
from openai.lib.azure import AzureADTokenProvider, AzureOpenAI
from openai.lib.azure import AzureOpenAI
from haystack import default_to_dict, default_from_dict
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.components.generators.utils import serialize_callback_handler, deserialize_callback_handler
from haystack.dataclasses import StreamingChunk
from haystack.utils import Secret, deserialize_secrets_inplace
logger = logging.getLogger(__name__)
@ -28,11 +29,12 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
```python
from haystack.components.generators.chat import AzureOpenAIChatGenerator
from haystack.dataclasses import ChatMessage
from haystack.utils import Secret
messages = [ChatMessage.from_user("What's Natural Language Processing?")]
client = AzureOpenAIChatGenerator(azure_endpoint="<Your Azure endpoint e.g. `https://your-company.azure.openai.com/>",
api_key="<your api key>",
api_key=Secret.from_token("<your-api-key>"),
azure_deployment="<this is a model name, e.g. gpt-35-turbo>")
response = client.run(messages)
print(response)
@ -63,9 +65,8 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
azure_endpoint: Optional[str] = None,
api_version: Optional[str] = "2023-05-15",
azure_deployment: Optional[str] = "gpt-35-turbo",
api_key: Optional[str] = None,
azure_ad_token: Optional[str] = None,
azure_ad_token_provider: Optional[AzureADTokenProvider] = None,
api_key: Optional[Secret] = Secret.from_env_var("AZURE_OPENAI_API_KEY", strict=False),
azure_ad_token: Optional[Secret] = Secret.from_env_var("AZURE_OPENAI_AD_TOKEN", strict=False),
organization: Optional[str] = None,
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
generation_kwargs: Optional[Dict[str, Any]] = None,
@ -76,8 +77,6 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
:param azure_deployment: The deployment of the model, usually the model name.
:param api_key: The API key to use for authentication.
:param azure_ad_token: Azure Active Directory token, see https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id
:param azure_ad_token_provider: A function that returns an Azure Active Directory token, will be invoked
on every request.
:param organization: The Organization ID, defaults to `None`. See
[production best practices](https://platform.openai.com/docs/guides/production-best-practices/setting-up-your-organization).
:param streaming_callback: A callback function that is called when a new token is received from the stream.
@ -113,6 +112,13 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
if not azure_endpoint:
raise ValueError("Please provide an Azure endpoint or set the environment variable AZURE_OPENAI_ENDPOINT.")
if api_key is None and azure_ad_token is None:
raise ValueError("Please provide an API key or an Azure Active Directory token.")
# The check above makes mypy incorrectly infer that api_key is never None,
# which propagates the incorrect type.
self.api_key = api_key # type: ignore
self.azure_ad_token = azure_ad_token
self.generation_kwargs = generation_kwargs or {}
self.streaming_callback = streaming_callback
self.api_version = api_version
@ -125,9 +131,8 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
api_version=api_version,
azure_endpoint=azure_endpoint,
azure_deployment=azure_deployment,
api_key=api_key,
azure_ad_token=azure_ad_token,
azure_ad_token_provider=azure_ad_token_provider,
api_key=api_key.resolve_value() if api_key is not None else None,
azure_ad_token=azure_ad_token.resolve_value() if azure_ad_token is not None else None,
organization=organization,
)
@ -145,6 +150,8 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
api_version=self.api_version,
streaming_callback=callback_name,
generation_kwargs=self.generation_kwargs,
api_key=self.api_key.to_dict() if self.api_key is not None else None,
azure_ad_token=self.azure_ad_token.to_dict() if self.azure_ad_token is not None else None,
)
@classmethod
@ -154,6 +161,7 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
:param data: The dictionary representation of this component.
:return: The deserialized component instance.
"""
deserialize_secrets_inplace(data["init_parameters"], keys=["api_key", "azure_ad_token"])
init_params = data.get("init_parameters", {})
serialized_callback_handler = init_params.get("streaming_callback")
if serialized_callback_handler:

View File

@ -9,6 +9,7 @@ from haystack.components.generators.utils import serialize_callback_handler, des
from haystack.dataclasses import ChatMessage, StreamingChunk
from haystack.lazy_imports import LazyImport
from haystack.utils import ComponentDevice
from haystack.utils import Secret, deserialize_secrets_inplace
logger = logging.getLogger(__name__)
@ -57,7 +58,7 @@ class HuggingFaceLocalChatGenerator:
model: str = "HuggingFaceH4/zephyr-7b-beta",
task: Optional[Literal["text-generation", "text2text-generation"]] = None,
device: Optional[ComponentDevice] = None,
token: Optional[Union[str, bool]] = None,
token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
chat_template: Optional[str] = None,
generation_kwargs: Optional[Dict[str, Any]] = None,
huggingface_pipeline_kwargs: Optional[Dict[str, Any]] = None,
@ -80,7 +81,6 @@ class HuggingFaceLocalChatGenerator:
:param device: The device on which the model is loaded. If `None`, the default device is automatically
selected. If a device/device map is specified in `huggingface_pipeline_kwargs`, it overrides this parameter.
:param token: The token to use as HTTP bearer authorization for remote files.
If True, will use the token generated when running huggingface-cli login (stored in ~/.huggingface).
If the token is also specified in the `huggingface_pipeline_kwargs`, this parameter will be ignored.
:param chat_template: This optional parameter allows you to specify a Jinja template for formatting chat
messages. While high-quality and well-supported chat models typically include their own chat templates
@ -113,6 +113,9 @@ class HuggingFaceLocalChatGenerator:
huggingface_pipeline_kwargs = huggingface_pipeline_kwargs or {}
generation_kwargs = generation_kwargs or {}
self.token = token
token = token.resolve_value() if token else None
# check if the huggingface_pipeline_kwargs contain the essential parameters
# otherwise, populate them with values from other init parameters
huggingface_pipeline_kwargs.setdefault("model", model)
@ -178,12 +181,11 @@ class HuggingFaceLocalChatGenerator:
huggingface_pipeline_kwargs=self.huggingface_pipeline_kwargs,
generation_kwargs=self.generation_kwargs,
streaming_callback=callback_name,
token=self.token.to_dict() if self.token else None,
)
huggingface_pipeline_kwargs = serialization_dict["init_parameters"]["huggingface_pipeline_kwargs"]
# we don't want to serialize valid tokens
if isinstance(huggingface_pipeline_kwargs["token"], str):
serialization_dict["init_parameters"]["huggingface_pipeline_kwargs"].pop("token")
huggingface_pipeline_kwargs.pop("token", None)
serialize_hf_model_kwargs(huggingface_pipeline_kwargs)
return serialization_dict
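
The `pop` above exists because `__init__` resolves the secret into `huggingface_pipeline_kwargs` for the transformers pipeline, so by serialization time that dict holds the raw token string. The `Secret` is serialized separately and the resolved copy is dropped. An illustrative sketch with a hypothetical payload (dict shape assumed from `Secret.to_dict()`):

```python
serialization_dict = {
    "init_parameters": {
        # the Secret, serialized safely as a pointer to the env var
        "token": {"type": "env_var", "env_vars": ["HF_API_TOKEN"], "strict": False},
        # the kwargs passed to transformers, holding the *resolved* token
        "huggingface_pipeline_kwargs": {"model": "gpt2", "token": "hf_raw_token_value"},
    }
}
hf_kwargs = serialization_dict["init_parameters"]["huggingface_pipeline_kwargs"]
hf_kwargs.pop("token", None)  # drop the raw value before writing to disk
assert "token" not in hf_kwargs
```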
@ -194,7 +196,7 @@ class HuggingFaceLocalChatGenerator:
Deserialize this component from a dictionary.
"""
torch_and_transformers_import.check() # leave this, cls method
deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
init_params = data.get("init_parameters", {})
serialized_callback_handler = init_params.get("streaming_callback")
if serialized_callback_handler:

View File

@ -8,6 +8,7 @@ from haystack.components.generators.utils import serialize_callback_handler, des
from haystack.dataclasses import ChatMessage, StreamingChunk
from haystack.components.generators.hf_utils import check_valid_model, check_generation_params
from haystack.lazy_imports import LazyImport
from haystack.utils import Secret, deserialize_secrets_inplace
with LazyImport(message="Run 'pip install transformers'") as transformers_import:
from huggingface_hub import InferenceClient
@ -28,12 +29,13 @@ class HuggingFaceTGIChatGenerator:
```python
from haystack.components.generators.chat import HuggingFaceTGIChatGenerator
from haystack.dataclasses import ChatMessage
from haystack.utils import Secret
messages = [ChatMessage.from_system("\\nYou are a helpful, respectful and honest assistant"),
ChatMessage.from_user("What's Natural Language Processing?")]
client = HuggingFaceTGIChatGenerator(model="meta-llama/Llama-2-70b-chat-hf", token="<your-token>")
client = HuggingFaceTGIChatGenerator(model="meta-llama/Llama-2-70b-chat-hf", token=Secret.from_token("<your-api-key>"))
client.warm_up()
response = client.run(messages, generation_kwargs={"max_new_tokens": 120})
print(response)
@ -51,7 +53,7 @@ class HuggingFaceTGIChatGenerator:
client = HuggingFaceTGIChatGenerator(model="meta-llama/Llama-2-70b-chat-hf",
url="<your-tgi-endpoint-url>",
token="<your-token>")
token=Secret.from_token("<your-api-key>"))
client.warm_up()
response = client.run(messages, generation_kwargs={"max_new_tokens": 120})
print(response)
@ -85,7 +87,7 @@ class HuggingFaceTGIChatGenerator:
self,
model: str = "meta-llama/Llama-2-13b-chat-hf",
url: Optional[str] = None,
token: Optional[str] = None,
token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
chat_template: Optional[str] = None,
generation_kwargs: Optional[Dict[str, Any]] = None,
stop_words: Optional[List[str]] = None,
@ -131,7 +133,7 @@ class HuggingFaceTGIChatGenerator:
self.chat_template = chat_template
self.token = token
self.generation_kwargs = generation_kwargs
self.client = InferenceClient(url or model, token=token)
self.client = InferenceClient(url or model, token=token.resolve_value() if token else None)
self.streaming_callback = streaming_callback
self.tokenizer = None
@ -139,7 +141,9 @@ class HuggingFaceTGIChatGenerator:
"""
Load the tokenizer.
"""
self.tokenizer = AutoTokenizer.from_pretrained(self.model, token=self.token)
self.tokenizer = AutoTokenizer.from_pretrained(
self.model, token=self.token.resolve_value() if self.token else None
)
# mypy can't infer that chat_template attribute exists on the object returned by AutoTokenizer.from_pretrained
chat_template = getattr(self.tokenizer, "chat_template", None)
if not chat_template and not self.chat_template:
@ -162,7 +166,7 @@ class HuggingFaceTGIChatGenerator:
model=self.model,
url=self.url,
chat_template=self.chat_template,
token=self.token if not isinstance(self.token, str) else None, # don't serialize valid tokens
token=self.token.to_dict() if self.token else None,
generation_kwargs=self.generation_kwargs,
streaming_callback=callback_name,
)
@ -172,6 +176,7 @@ class HuggingFaceTGIChatGenerator:
"""
Deserialize this component from a dictionary.
"""
deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
init_params = data.get("init_parameters", {})
serialized_callback_handler = init_params.get("streaming_callback")
if serialized_callback_handler:

View File

@ -13,6 +13,7 @@ from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice
from haystack import component, default_from_dict, default_to_dict
from haystack.components.generators.utils import serialize_callback_handler, deserialize_callback_handler
from haystack.dataclasses import StreamingChunk, ChatMessage
from haystack.utils import Secret, deserialize_secrets_inplace
logger = logging.getLogger(__name__)
@ -62,7 +63,7 @@ class OpenAIChatGenerator:
def __init__(
self,
api_key: Optional[str] = None,
api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
model: str = "gpt-3.5-turbo",
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
api_base_url: Optional[str] = None,
@ -73,8 +74,7 @@ class OpenAIChatGenerator:
Creates an instance of OpenAIChatGenerator. Unless specified otherwise in the `model`, this is for OpenAI's
GPT-3.5 model.
:param api_key: The OpenAI API key. It can be explicitly provided or automatically read from the
environment variable OPENAI_API_KEY (recommended).
:param api_key: The OpenAI API key.
:param model: The name of the model to use.
:param streaming_callback: A callback function that is called when a new token is received from the stream.
The callback function accepts StreamingChunk as an argument.
@ -101,12 +101,13 @@ class OpenAIChatGenerator:
- `logit_bias`: Add a logit bias to specific tokens. The keys of the dictionary are tokens, and the
values are the bias to add to that token.
"""
self.api_key = api_key
self.model = model
self.generation_kwargs = generation_kwargs or {}
self.streaming_callback = streaming_callback
self.api_base_url = api_base_url
self.organization = organization
self.client = OpenAI(api_key=api_key, organization=organization, base_url=api_base_url)
self.client = OpenAI(api_key=api_key.resolve_value(), organization=organization, base_url=api_base_url)
def _get_telemetry_data(self) -> Dict[str, Any]:
"""
@ -127,6 +128,7 @@ class OpenAIChatGenerator:
api_base_url=self.api_base_url,
organization=self.organization,
generation_kwargs=self.generation_kwargs,
api_key=self.api_key.to_dict(),
)
@classmethod
@ -136,6 +138,7 @@ class OpenAIChatGenerator:
:param data: The dictionary representation of this component.
:return: The deserialized component instance.
"""
deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"])
init_params = data.get("init_parameters", {})
serialized_callback_handler = init_params.get("streaming_callback")
if serialized_callback_handler:
@ -334,7 +337,7 @@ class OpenAIChatGenerator:
class GPTChatGenerator(OpenAIChatGenerator):
def __init__(
self,
api_key: Optional[str] = None,
api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
model: str = "gpt-3.5-turbo",
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
api_base_url: Optional[str] = None,

View File

@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional, Union, Callable
from haystack.dataclasses import StreamingChunk
from haystack.lazy_imports import LazyImport
from haystack.utils import Secret
with LazyImport(message="Run 'pip install transformers'") as transformers_import:
from huggingface_hub import InferenceClient, HfApi
@ -36,20 +37,20 @@ def check_generation_params(kwargs: Optional[Dict[str, Any]], additional_accepte
)
def check_valid_model(model_id: str, token: Optional[str]) -> None:
def check_valid_model(model_id: str, token: Optional[Secret]) -> None:
"""
Check if the provided model ID corresponds to a valid model on HuggingFace Hub.
Also check if the model is a text generation model.
:param model_id: A string representing the HuggingFace model ID.
:param token: An optional string representing the authentication token.
:param token: An optional authentication token.
:raises ValueError: If the model is not found or is not a text generation model.
"""
transformers_import.check()
api = HfApi()
try:
model_info = api.model_info(model_id, token=token)
model_info = api.model_info(model_id, token=token.resolve_value() if token else None)
except RepositoryNotFoundError as e:
raise ValueError(
f"Model {model_id} not found on HuggingFace Hub. Please provide a valid HuggingFace model_id."

View File

@ -1,11 +1,12 @@
import logging
from typing import Any, Dict, List, Literal, Optional, Union
from typing import Any, Dict, List, Literal, Optional
from haystack import component, default_from_dict, default_to_dict
from haystack.lazy_imports import LazyImport
from haystack.utils import ComponentDevice
from haystack.utils.hf import deserialize_hf_model_kwargs, serialize_hf_model_kwargs
from haystack.utils import Secret, deserialize_secrets_inplace
logger = logging.getLogger(__name__)
@ -44,7 +45,7 @@ class HuggingFaceLocalGenerator:
model: str = "google/flan-t5-base",
task: Optional[Literal["text-generation", "text2text-generation"]] = None,
device: Optional[ComponentDevice] = None,
token: Optional[Union[str, bool]] = None,
token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
generation_kwargs: Optional[Dict[str, Any]] = None,
huggingface_pipeline_kwargs: Optional[Dict[str, Any]] = None,
stop_words: Optional[List[str]] = None,
@ -63,7 +64,6 @@ class HuggingFaceLocalGenerator:
:param device: The device on which the model is loaded. If `None`, the default device is automatically
selected. If a device/device map is specified in `huggingface_pipeline_kwargs`, it overrides this parameter.
:param token: The token to use as HTTP bearer authorization for remote files.
If True, will use the token generated when running huggingface-cli login (stored in ~/.huggingface).
If the token is also specified in the `huggingface_pipeline_kwargs`, this parameter will be ignored.
:param generation_kwargs: A dictionary containing keyword arguments to customize text generation.
Some examples: `max_length`, `max_new_tokens`, `temperature`, `top_k`, `top_p`,...
@ -89,6 +89,9 @@ class HuggingFaceLocalGenerator:
huggingface_pipeline_kwargs = huggingface_pipeline_kwargs or {}
generation_kwargs = generation_kwargs or {}
self.token = token
token = token.resolve_value() if token else None
# check if the huggingface_pipeline_kwargs contain the essential parameters
# otherwise, populate them with values from other init parameters
huggingface_pipeline_kwargs.setdefault("model", model)
@ -156,12 +159,11 @@ class HuggingFaceLocalGenerator:
huggingface_pipeline_kwargs=self.huggingface_pipeline_kwargs,
generation_kwargs=self.generation_kwargs,
stop_words=self.stop_words,
token=self.token.to_dict() if self.token else None,
)
huggingface_pipeline_kwargs = serialization_dict["init_parameters"]["huggingface_pipeline_kwargs"]
# we don't want to serialize valid tokens
if isinstance(huggingface_pipeline_kwargs["token"], str):
serialization_dict["init_parameters"]["huggingface_pipeline_kwargs"].pop("token")
huggingface_pipeline_kwargs.pop("token", None)
serialize_hf_model_kwargs(huggingface_pipeline_kwargs)
return serialization_dict
@ -171,6 +173,7 @@ class HuggingFaceLocalGenerator:
"""
Deserialize this component from a dictionary.
"""
deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
deserialize_hf_model_kwargs(data["init_parameters"]["huggingface_pipeline_kwargs"])
return default_from_dict(cls, data)

View File

@ -8,6 +8,7 @@ from haystack.components.generators.utils import serialize_callback_handler, des
from haystack.dataclasses import StreamingChunk
from haystack.components.generators.hf_utils import check_generation_params, check_valid_model
from haystack.lazy_imports import LazyImport
from haystack.utils import Secret, deserialize_secrets_inplace
with LazyImport(message="Run 'pip install transformers'") as transformers_import:
from huggingface_hub import InferenceClient
@ -29,7 +30,9 @@ class HuggingFaceTGIGenerator:
```python
from haystack.components.generators import HuggingFaceTGIGenerator
client = HuggingFaceTGIGenerator(model="mistralai/Mistral-7B-v0.1", token="<your-token>")
from haystack.utils import Secret
client = HuggingFaceTGIGenerator(model="mistralai/Mistral-7B-v0.1", token=Secret.from_token("<your-api-key>"))
client.warm_up()
response = client.run("What's Natural Language Processing?", max_new_tokens=120)
print(response)
@ -42,7 +45,7 @@ class HuggingFaceTGIGenerator:
from haystack.components.generators import HuggingFaceTGIGenerator
client = HuggingFaceTGIGenerator(model="mistralai/Mistral-7B-v0.1",
url="<your-tgi-endpoint-url>",
token="<your-token>")
token=Secret.from_token("<your-api-key>"))
client.warm_up()
response = client.run("What's Natural Language Processing?")
print(response)
@ -74,7 +77,7 @@ class HuggingFaceTGIGenerator:
self,
model: str = "mistralai/Mistral-7B-v0.1",
url: Optional[str] = None,
token: Optional[str] = None,
token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
generation_kwargs: Optional[Dict[str, Any]] = None,
stop_words: Optional[List[str]] = None,
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
@ -113,7 +116,7 @@ class HuggingFaceTGIGenerator:
self.url = url
self.token = token
self.generation_kwargs = generation_kwargs
self.client = InferenceClient(url or model, token=token)
self.client = InferenceClient(url or model, token=token.resolve_value() if token else None)
self.streaming_callback = streaming_callback
self.tokenizer = None
@ -121,7 +124,9 @@ class HuggingFaceTGIGenerator:
"""
Load the tokenizer
"""
self.tokenizer = AutoTokenizer.from_pretrained(self.model, token=self.token)
self.tokenizer = AutoTokenizer.from_pretrained(
self.model, token=self.token.resolve_value() if self.token else None
)
def to_dict(self) -> Dict[str, Any]:
"""
@ -134,7 +139,7 @@ class HuggingFaceTGIGenerator:
self,
model=self.model,
url=self.url,
token=self.token if not isinstance(self.token, str) else None, # don't serialize valid tokens
token=self.token.to_dict() if self.token else None,
generation_kwargs=self.generation_kwargs,
streaming_callback=callback_name,
)
@ -144,6 +149,7 @@ class HuggingFaceTGIGenerator:
"""
Deserialize this component from a dictionary.
"""
deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
init_params = data.get("init_parameters", {})
serialized_callback_handler = init_params.get("streaming_callback")
if serialized_callback_handler:
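The `token.resolve_value() if token else None` guards above lean on the two `Secret` policies introduced in this commit: a non-strict env-var secret resolves to `None` when nothing is set, while a strict one raises. A small sketch of those semantics:

```python
import os
from haystack.utils import Secret

os.environ.pop("HF_API_TOKEN", None)

# Non-strict: missing variables resolve to None, so anonymous access still works.
assert Secret.from_env_var("HF_API_TOKEN", strict=False).resolve_value() is None

# Strict (the default): missing variables raise a ValueError instead.
try:
    Secret.from_env_var("HF_API_TOKEN").resolve_value()
except ValueError as err:
    print(err)  # "None of the following authentication environment variables are set: ..."
```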

View File

@ -9,6 +9,7 @@ from openai.types.chat import ChatCompletionChunk, ChatCompletion
from haystack import component, default_from_dict, default_to_dict
from haystack.components.generators.utils import serialize_callback_handler, deserialize_callback_handler
from haystack.dataclasses import StreamingChunk, ChatMessage
from haystack.utils import Secret, deserialize_secrets_inplace
logger = logging.getLogger(__name__)
@ -50,7 +51,7 @@ class OpenAIGenerator:
def __init__(
self,
api_key: Optional[str] = None,
api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
model: str = "gpt-3.5-turbo",
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
api_base_url: Optional[str] = None,
@ -62,8 +63,7 @@ class OpenAIGenerator:
Creates an instance of OpenAIGenerator. Unless specified otherwise in the `model`, this is for OpenAI's
GPT-3.5 model.
:param api_key: The OpenAI API key. It can be explicitly provided or automatically read from the
environment variable OPENAI_API_KEY (recommended).
:param api_key: The OpenAI API key.
:param model: The name of the model to use.
:param streaming_callback: A callback function that is called when a new token is received from the stream.
The callback function accepts StreamingChunk as an argument.
@ -92,6 +92,7 @@ class OpenAIGenerator:
- `logit_bias`: Add a logit bias to specific tokens. The keys of the dictionary are tokens, and the
values are the bias to add to that token.
"""
self.api_key = api_key
self.model = model
self.generation_kwargs = generation_kwargs or {}
self.system_prompt = system_prompt
@ -99,7 +100,7 @@ class OpenAIGenerator:
self.api_base_url = api_base_url
self.organization = organization
self.client = OpenAI(api_key=api_key, organization=organization, base_url=api_base_url)
self.client = OpenAI(api_key=api_key.resolve_value(), organization=organization, base_url=api_base_url)
def _get_telemetry_data(self) -> Dict[str, Any]:
"""
@ -120,6 +121,7 @@ class OpenAIGenerator:
api_base_url=self.api_base_url,
generation_kwargs=self.generation_kwargs,
system_prompt=self.system_prompt,
api_key=self.api_key.to_dict(),
)
@classmethod
@ -129,6 +131,7 @@ class OpenAIGenerator:
:param data: The dictionary representation of this component.
:return: The deserialized component instance.
"""
deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"])
init_params = data.get("init_parameters", {})
serialized_callback_handler = init_params.get("streaming_callback")
if serialized_callback_handler:
@ -279,7 +282,7 @@ class OpenAIGenerator:
class GPTGenerator(OpenAIGenerator):
def __init__(
self,
api_key: Optional[str] = None,
api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
model: str = "gpt-3.5-turbo",
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
api_base_url: Optional[str] = None,
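Because the default is now a strict `OPENAI_API_KEY` secret, a missing key fails during `__init__` (at `resolve_value()`) with a `ValueError` rather than surfacing later as an `OpenAIError`, and explicit tokens refuse to serialize. A sketch of both behaviours, assuming `OPENAI_API_KEY` is unset:

```python
from haystack.components.generators import OpenAIGenerator
from haystack.utils import Secret

# Strict env-var secret: construction fails fast if OPENAI_API_KEY is unset.
try:
    OpenAIGenerator()
except ValueError as err:
    print(err)

# Token-based secrets work at runtime...
generator = OpenAIGenerator(api_key=Secret.from_token("test-api-key"))

# ...but to_dict() raises instead of leaking the credential to disk.
try:
    generator.to_dict()
except ValueError as err:
    print(err)
```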

View File

@ -6,6 +6,7 @@ from haystack import ComponentError, Document, component, default_from_dict, def
from haystack.lazy_imports import LazyImport
from haystack.utils import ComponentDevice, DeviceMap
from haystack.utils.hf import deserialize_hf_model_kwargs, serialize_hf_model_kwargs, resolve_hf_device_map
from haystack.utils import Secret, deserialize_secrets_inplace
logger = logging.getLogger(__name__)
@ -40,7 +41,7 @@ class TransformersSimilarityRanker:
self,
model: Union[str, Path] = "cross-encoder/ms-marco-MiniLM-L-6-v2",
device: Optional[ComponentDevice] = None,
token: Union[bool, str, None] = None,
token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
top_k: int = 10,
query_prefix: str = "",
document_prefix: str = "",
@ -119,9 +120,11 @@ class TransformersSimilarityRanker:
"""
if self.model is None:
self.model = AutoModelForSequenceClassification.from_pretrained(
self.model_name_or_path, token=self.token, **self.model_kwargs
self.model_name_or_path, token=self.token.resolve_value() if self.token else None, **self.model_kwargs
)
self.tokenizer = AutoTokenizer.from_pretrained(
self.model_name_or_path, token=self.token.resolve_value() if self.token else None
)
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, token=self.token)
self.device = ComponentDevice.from_multiple(device_map=DeviceMap.from_hf(self.model.hf_device_map))
def to_dict(self) -> Dict[str, Any]:
@ -132,7 +135,7 @@ class TransformersSimilarityRanker:
self,
device=None,
model=self.model_name_or_path,
token=self.token if not isinstance(self.token, str) else None, # don't serialize valid tokens
token=self.token.to_dict() if self.token else None,
top_k=self.top_k,
query_prefix=self.query_prefix,
document_prefix=self.document_prefix,
@ -152,6 +155,7 @@ class TransformersSimilarityRanker:
"""
Deserialize this component from a dictionary.
"""
deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
init_params = data["init_parameters"]
if init_params["device"] is not None:
init_params["device"] = ComponentDevice.from_dict(init_params["device"])

View File

@ -1,11 +1,11 @@
import json
import logging
from typing import Dict, List, Optional, Any
import os
import requests
from haystack import Document, component, default_to_dict, ComponentError
from haystack import Document, component, default_to_dict, ComponentError, default_from_dict
from haystack.utils import Secret, deserialize_secrets_inplace
logger = logging.getLogger(__name__)
@ -27,43 +27,48 @@ class SearchApiWebSearch:
def __init__(
self,
api_key: Optional[str] = None,
api_key: Secret = Secret.from_env_var("SEARCHAPI_API_KEY"),
top_k: Optional[int] = 10,
allowed_domains: Optional[List[str]] = None,
search_params: Optional[Dict[str, Any]] = None,
):
"""
:param api_key: API key for the SearchApi API. It can be
explicitly provided or automatically read from the
environment variable SEARCHAPI_API_KEY (recommended).
:param api_key: API key for the SearchApi API.
:param top_k: Number of documents to return.
:param allowed_domains: List of domains to limit the search to.
:param search_params: Additional parameters passed to the SearchApi API.
For example, you can set 'num' to 100 to increase the number of search results.
See the [SearchApi website](https://www.searchapi.io/) for more details.
"""
api_key = api_key or os.environ.get("SEARCHAPI_API_KEY")
# we check whether api_key is None or an empty string
if not api_key:
msg = (
"SearchApiWebSearch expects an API key. "
"Set the SEARCHAPI_API_KEY environment variable (recommended) or pass it explicitly."
)
raise ValueError(msg)
self.api_key = api_key
self.top_k = top_k
self.allowed_domains = allowed_domains
self.search_params = search_params or {}
# Ensure that the API key is resolved.
_ = self.api_key.resolve_value()
def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary.
"""
return default_to_dict(
self, top_k=self.top_k, allowed_domains=self.allowed_domains, search_params=self.search_params
self,
top_k=self.top_k,
allowed_domains=self.allowed_domains,
search_params=self.search_params,
api_key=self.api_key.to_dict(),
)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "SearchApiWebSearch":
"""
Deserialize this component from a dictionary.
"""
deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"])
return default_from_dict(cls, data)
@component.output_types(documents=List[Document], links=List[str])
def run(self, query: str):
"""
@ -74,7 +79,7 @@ class SearchApiWebSearch:
query_prepend = "OR ".join(f"site:{domain} " for domain in self.allowed_domains) if self.allowed_domains else ""
payload = json.dumps({"q": query_prepend + " " + query, **self.search_params})
headers = {"Authorization": f"Bearer {self.api_key}", "X-SearchApi-Source": "Haystack"}
headers = {"Authorization": f"Bearer {self.api_key.resolve_value()}", "X-SearchApi-Source": "Haystack"}
try:
response = requests.get(SEARCHAPI_BASE_URL, headers=headers, params=payload, timeout=90)
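The explicit `os.environ` check is gone because `resolve_value()` on the strict default secret performs the same fail-fast validation; the trailing `_ = self.api_key.resolve_value()` exists purely for that side effect. A sketch with a hypothetical key value, assuming the component is exported from `haystack.components.websearch`:

```python
import os
from haystack.components.websearch import SearchApiWebSearch
from haystack.utils import Secret

# The default secret is strict: if SEARCHAPI_API_KEY were unset, __init__
# would raise a ValueError at the resolve_value() call above.
os.environ["SEARCHAPI_API_KEY"] = "hypothetical-key"
search = SearchApiWebSearch(top_k=5, search_params={"num": 100})

# An explicit token also works, but to_dict() would then raise instead of
# leaking the credential to disk.
search = SearchApiWebSearch(api_key=Secret.from_token("hypothetical-key"))
```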

View File

@ -1,11 +1,11 @@
import json
import logging
from typing import Dict, List, Optional, Any
import os
import requests
from haystack import Document, component, default_to_dict, ComponentError
from haystack import Document, component, default_to_dict, ComponentError, default_from_dict
from haystack.utils import Secret, deserialize_secrets_inplace
logger = logging.getLogger(__name__)
@ -27,43 +27,47 @@ class SerperDevWebSearch:
def __init__(
self,
api_key: Optional[str] = None,
api_key: Secret = Secret.from_env_var("SERPERDEV_API_KEY"),
top_k: Optional[int] = 10,
allowed_domains: Optional[List[str]] = None,
search_params: Optional[Dict[str, Any]] = None,
):
"""
:param api_key: API key for the SerperDev API. It can be
explicitly provided or automatically read from the
environment variable SERPERDEV_API_KEY (recommended).
:param api_key: API key for the SerperDev API.
:param top_k: Number of documents to return.
:param allowed_domains: List of domains to limit the search to.
:param search_params: Additional parameters passed to the SerperDev API.
For example, you can set 'num' to 20 to increase the number of search results.
See the [Serper Dev website](https://serper.dev/) for more details.
"""
api_key = api_key or os.environ.get("SERPERDEV_API_KEY")
# we check whether api_key is None or an empty string
if not api_key:
msg = (
"SerperDevWebSearch expects an API key. "
"Set the SERPERDEV_API_KEY environment variable (recommended) or pass it explicitly."
)
raise ValueError(msg)
self.api_key = api_key
self.top_k = top_k
self.allowed_domains = allowed_domains
self.search_params = search_params or {}
# Ensure that the API key is resolved.
_ = self.api_key.resolve_value()
def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary.
"""
return default_to_dict(
self, top_k=self.top_k, allowed_domains=self.allowed_domains, search_params=self.search_params
self,
top_k=self.top_k,
allowed_domains=self.allowed_domains,
search_params=self.search_params,
api_key=self.api_key.to_dict(),
)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "SerperDevWebSearch":
"""
Deserialize this component from a dictionary.
"""
deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"])
return default_from_dict(cls, data)
@component.output_types(documents=List[Document], links=List[str])
def run(self, query: str):
"""
@ -76,10 +80,10 @@ class SerperDevWebSearch:
payload = json.dumps(
{"q": query_prepend + query, "gl": "us", "hl": "en", "autocorrect": True, **self.search_params}
)
headers = {"X-API-KEY": self.api_key, "Content-Type": "application/json"}
headers = {"X-API-KEY": self.api_key.resolve_value(), "Content-Type": "application/json"}
try:
response = requests.post(SERPERDEV_BASE_URL, headers=headers, data=payload, timeout=30)
response = requests.post(SERPERDEV_BASE_URL, headers=headers, data=payload, timeout=30) # type: ignore
response.raise_for_status() # Will raise an HTTPError for bad responses
except requests.Timeout:
raise TimeoutError(f"Request to {self.__class__.__name__} timed out.")

View File

@ -2,4 +2,4 @@ from haystack.utils.expit import expit
from haystack.utils.requests_utils import request_with_retry
from haystack.utils.filters import document_matches_filter
from haystack.utils.device import ComponentDevice, DeviceType, Device, DeviceMap
from haystack.utils.auth import Secret
from haystack.utils.auth import Secret, deserialize_secrets_inplace

View File

@ -1,6 +1,6 @@
from enum import Enum
import os
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, Iterable, List, Optional, Union
from dataclasses import dataclass
from abc import ABC, abstractmethod
@ -193,3 +193,21 @@ class EnvVarSecret(Secret):
if out is None and self._strict:
raise ValueError(f"None of the following authentication environment variables are set: {self._env_vars}")
return out
def deserialize_secrets_inplace(data: Dict[str, Any], keys: Iterable[str], *, recursive: bool = False):
    """
    Deserialize secrets in a dictionary in place.
    :param data:
        The dictionary with the serialized data.
    :param keys:
        The keys of the secrets to deserialize.
    :param recursive:
        Whether to recursively deserialize nested dictionaries.
    """
    for k, v in data.items():
        if isinstance(v, dict) and recursive:
            deserialize_secrets_inplace(v, keys, recursive=recursive)
        elif k in keys and v is not None:
            data[k] = Secret.from_dict(v)
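A sketch of the helper on a serialized `init_parameters` mapping; with `recursive=True`, nested dictionaries are walked as well:

```python
from haystack.utils import Secret, deserialize_secrets_inplace

data = {
    "api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
    "model": "gpt-3.5-turbo",
    "nested": {"token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"}},
}
deserialize_secrets_inplace(data, keys=["api_key", "token"], recursive=True)

assert isinstance(data["api_key"], Secret)
assert isinstance(data["nested"]["token"], Secret)
assert data["model"] == "gpt-3.5-turbo"  # non-secret values are untouched
```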

View File

@ -4,3 +4,37 @@ features:
Expose a `Secret` type to provide consistent API for any component that requires secrets for authentication.
Currently supports string tokens and environment variables. Token-based secrets are automatically
prevented from being serialized to disk (to prevent accidental leakage of secrets).
```python
@component
class MyComponent:
    def __init__(self, api_key: Optional[Secret] = None, **kwargs):
        self.api_key = api_key
        self.backend = None

    def warm_up(self):
        # Call resolve_value to yield a single result. The semantics of the result are policy-dependent.
        # Currently, all supported policies will return a single string token.
        self.backend = SomeBackend(api_key=self.api_key.resolve_value() if self.api_key else None, ...)

    def to_dict(self):
        # Serialize the policy like any other (custom) data. If the policy is token-based, it will
        # raise an error.
        return default_to_dict(self, api_key=self.api_key.to_dict() if self.api_key else None, ...)

    @classmethod
    def from_dict(cls, data):
        # Deserialize the policy data before passing it to the generic from_dict function.
        api_key_data = data["init_parameters"]["api_key"]
        api_key = Secret.from_dict(api_key_data) if api_key_data is not None else None
        data["init_parameters"]["api_key"] = api_key
        return default_from_dict(cls, data)


# No authentication.
component = MyComponent(api_key=None)

# Token-based authentication.
component = MyComponent(api_key=Secret.from_token("sk-randomAPIkeyasdsa32ekasd32e"))
component.to_dict()  # Error! Can't serialize authentication tokens

# Environment variable-based authentication.
component = MyComponent(api_key=Secret.from_env_var("OPENAI_API_KEY"))
component.to_dict()  # This is fine
```
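The asymmetry above is deliberate: an environment-variable secret serializes only the variable names and the strictness flag, so pipeline definitions can be written to disk and shared safely, whereas a token-based secret carries the credential itself and therefore refuses serialization outright. Callers that need ad-hoc tokens are expected to construct them at load time rather than persist them.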

View File

@ -0,0 +1,8 @@
---
upgrade:
- |
Update secret handling for components using the `Secret` type. The following components are affected:
`RemoteWhisperTranscriber`, `AzureOCRDocumentConverter`, `AzureOpenAIDocumentEmbedder`, `AzureOpenAITextEmbedder`, `HuggingFaceTEIDocumentEmbedder`, `HuggingFaceTEITextEmbedder`, `OpenAIDocumentEmbedder`, `SentenceTransformersDocumentEmbedder`, `SentenceTransformersTextEmbedder`, `AzureOpenAIGenerator`, `AzureOpenAIChatGenerator`, `HuggingFaceLocalChatGenerator`, `HuggingFaceTGIChatGenerator`, `OpenAIChatGenerator`, `HuggingFaceLocalGenerator`, `HuggingFaceTGIGenerator`, `OpenAIGenerator`, `TransformersSimilarityRanker`, `SearchApiWebSearch`, `SerperDevWebSearch`
The default init parameters for `api_key`, `token`, and `azure_ad_token` have been adjusted to read from environment variables wherever possible. The `azure_ad_token_provider` parameter has been removed from Azure-based components. Hugging Face-based components now require either an explicit token or an environment variable when authentication is needed; the on-disk local token file is no longer supported.
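For callers, the migration is mechanical. A before/after sketch with a hypothetical key:

```python
from haystack.components.generators import OpenAIGenerator
from haystack.utils import Secret

# Before this change, plain strings were accepted:
# generator = OpenAIGenerator(api_key="sk-hypothetical")

# After: rely on the default environment variable...
generator = OpenAIGenerator()  # reads OPENAI_API_KEY

# ...or wrap the credential in a token-based Secret.
generator = OpenAIGenerator(api_key=Secret.from_token("sk-hypothetical"))
```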

View File

@ -1,29 +1,29 @@
import os
import pytest
from openai import OpenAIError
from haystack.components.audio.whisper_remote import RemoteWhisperTranscriber
from haystack.dataclasses import ByteStream
from haystack.utils import Secret
class TestRemoteWhisperTranscriber:
def test_init_no_key(self, monkeypatch):
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
with pytest.raises(OpenAIError):
RemoteWhisperTranscriber(api_key=None)
with pytest.raises(ValueError, match="None of the .* environment variables are set"):
RemoteWhisperTranscriber()
def test_init_key_env_var(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test_api_key")
t = RemoteWhisperTranscriber(api_key=None)
t = RemoteWhisperTranscriber()
assert t.client.api_key == "test_api_key"
def test_init_key_module_env_and_global_var(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test_api_key_2")
t = RemoteWhisperTranscriber(api_key=None)
t = RemoteWhisperTranscriber()
assert t.client.api_key == "test_api_key_2"
def test_init_default(self):
transcriber = RemoteWhisperTranscriber(api_key="test_api_key")
transcriber = RemoteWhisperTranscriber(api_key=Secret.from_token("test_api_key"))
assert transcriber.client.api_key == "test_api_key"
assert transcriber.model == "whisper-1"
assert transcriber.organization is None
@ -31,7 +31,7 @@ class TestRemoteWhisperTranscriber:
def test_init_custom_parameters(self):
transcriber = RemoteWhisperTranscriber(
api_key="test_api_key",
api_key=Secret.from_token("test_api_key"),
model="whisper-1",
organization="test-org",
api_base_url="test_api_url",
@ -42,6 +42,7 @@ class TestRemoteWhisperTranscriber:
)
assert transcriber.model == "whisper-1"
assert transcriber.api_key == Secret.from_token("test_api_key")
assert transcriber.organization == "test-org"
assert transcriber.api_base_url == "test_api_url"
assert transcriber.whisper_params == {
@ -51,12 +52,14 @@ class TestRemoteWhisperTranscriber:
"temperature": "0.5",
}
def test_to_dict_default_parameters(self):
transcriber = RemoteWhisperTranscriber(api_key="test_api_key")
def test_to_dict_default_parameters(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test_api_key")
transcriber = RemoteWhisperTranscriber()
data = transcriber.to_dict()
assert data == {
"type": "haystack.components.audio.whisper_remote.RemoteWhisperTranscriber",
"init_parameters": {
"api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
"model": "whisper-1",
"api_base_url": None,
"organization": None,
@ -64,9 +67,10 @@ class TestRemoteWhisperTranscriber:
},
}
def test_to_dict_with_custom_init_parameters(self):
def test_to_dict_with_custom_init_parameters(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test_api_key")
transcriber = RemoteWhisperTranscriber(
api_key="test_api_key",
api_key=Secret.from_env_var("ENV_VAR", strict=False),
model="whisper-1",
organization="test-org",
api_base_url="test_api_url",
@ -79,6 +83,7 @@ class TestRemoteWhisperTranscriber:
assert data == {
"type": "haystack.components.audio.whisper_remote.RemoteWhisperTranscriber",
"init_parameters": {
"api_key": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
"model": "whisper-1",
"organization": "test-org",
"api_base_url": "test_api_url",
@ -142,6 +147,7 @@ class TestRemoteWhisperTranscriber:
data = {
"type": "haystack.components.audio.whisper_remote.RemoteWhisperTranscriber",
"init_parameters": {
"api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
"model": "whisper-1",
"api_base_url": "https://api.openai.com/v1",
"organization": None,
@ -149,7 +155,7 @@ class TestRemoteWhisperTranscriber:
},
}
with pytest.raises(OpenAIError):
with pytest.raises(ValueError, match="None of the .* environment variables are set"):
RemoteWhisperTranscriber.from_dict(data)
@pytest.mark.skipif(

View File

@ -5,6 +5,7 @@ import pytest
from haystack.components.converters.azure import AzureOCRDocumentConverter
from haystack.dataclasses import ByteStream
from haystack.utils import Secret
class TestAzureOCRDocumentConverter:
@ -13,15 +14,23 @@ class TestAzureOCRDocumentConverter:
with pytest.raises(ValueError):
AzureOCRDocumentConverter(endpoint="test_endpoint")
def test_to_dict(self):
component = AzureOCRDocumentConverter(endpoint="test_endpoint", api_key="test_credential_key")
@patch("haystack.utils.auth.EnvVarSecret.resolve_value")
def test_to_dict(self, mock_resolve_value):
mock_resolve_value.return_value = "test_api_key"
component = AzureOCRDocumentConverter(endpoint="test_endpoint")
data = component.to_dict()
assert data == {
"type": "haystack.components.converters.azure.AzureOCRDocumentConverter",
"init_parameters": {"endpoint": "test_endpoint", "model_id": "prebuilt-read"},
"init_parameters": {
"api_key": {"env_vars": ["AZURE_AI_API_KEY"], "strict": True, "type": "env_var"},
"endpoint": "test_endpoint",
"model_id": "prebuilt-read",
},
}
def test_run(self, test_files_path):
@patch("haystack.utils.auth.EnvVarSecret.resolve_value")
def test_run(self, mock_resolve_value, test_files_path):
mock_resolve_value.return_value = "test_api_key"
with patch("haystack.components.converters.azure.DocumentAnalysisClient") as mock_azure_client:
mock_result = Mock(pages=[Mock(lines=[Mock(content="mocked line 1"), Mock(content="mocked line 2")])])
mock_result.to_dict.return_value = {
@ -32,7 +41,7 @@ class TestAzureOCRDocumentConverter:
}
mock_azure_client.return_value.begin_analyze_document.return_value.result.return_value = mock_result
component = AzureOCRDocumentConverter(endpoint="test_endpoint", api_key="test_credential_key")
component = AzureOCRDocumentConverter(endpoint="test_endpoint")
output = component.run(sources=[test_files_path / "pdf" / "sample_pdf_1.pdf"])
document = output["documents"][0]
assert document.content == "mocked line 1\nmocked line 2\n\f"
@ -44,11 +53,12 @@ class TestAzureOCRDocumentConverter:
"pages": [{"lines": [{"content": "mocked line 1"}, {"content": "mocked line 2"}]}],
}
def test_run_with_meta(self, test_files_path):
@patch("haystack.utils.auth.EnvVarSecret.resolve_value")
def test_run_with_meta(self, mock_resolve_value, test_files_path):
mock_resolve_value.return_value = "test_api_key"
bytestream = ByteStream(data=b"test", meta={"author": "test_author", "language": "en"})
with patch("haystack.components.converters.azure.DocumentAnalysisClient"):
component = AzureOCRDocumentConverter(endpoint="test_endpoint", api_key="test_credential_key")
component = AzureOCRDocumentConverter(endpoint="test_endpoint")
output = component.run(
sources=[bytestream, test_files_path / "pdf" / "sample_pdf_1.pdf"], meta={"language": "it"}
)
@ -63,7 +73,7 @@ class TestAzureOCRDocumentConverter:
@pytest.mark.skipif(not os.environ.get("CORE_AZURE_CS_API_KEY", None), reason="Azure credentials not available")
def test_run_with_pdf_file(self, test_files_path):
component = AzureOCRDocumentConverter(
endpoint=os.environ["CORE_AZURE_CS_ENDPOINT"], api_key=os.environ["CORE_AZURE_CS_API_KEY"]
endpoint=os.environ["CORE_AZURE_CS_ENDPOINT"], api_key=Secret.from_env_var("CORE_AZURE_CS_API_KEY")
)
output = component.run(sources=[test_files_path / "pdf" / "sample_pdf_1.pdf"])
documents = output["documents"]
@ -77,7 +87,7 @@ class TestAzureOCRDocumentConverter:
@pytest.mark.skipif(not os.environ.get("CORE_AZURE_CS_API_KEY", None), reason="Azure credentials not available")
def test_with_image_file(self, test_files_path):
component = AzureOCRDocumentConverter(
endpoint=os.environ["CORE_AZURE_CS_ENDPOINT"], api_key=os.environ["CORE_AZURE_CS_API_KEY"]
endpoint=os.environ["CORE_AZURE_CS_ENDPOINT"], api_key=Secret.from_env_var("CORE_AZURE_CS_API_KEY")
)
output = component.run(sources=[test_files_path / "images" / "haystack-logo.png"])
documents = output["documents"]
@ -90,7 +100,7 @@ class TestAzureOCRDocumentConverter:
@pytest.mark.skipif(not os.environ.get("CORE_AZURE_CS_API_KEY", None), reason="Azure credentials not available")
def test_run_with_docx_file(self, test_files_path):
component = AzureOCRDocumentConverter(
endpoint=os.environ["CORE_AZURE_CS_ENDPOINT"], api_key=os.environ["CORE_AZURE_CS_API_KEY"]
endpoint=os.environ["CORE_AZURE_CS_ENDPOINT"], api_key=Secret.from_env_var("CORE_AZURE_CS_API_KEY")
)
output = component.run(sources=[test_files_path / "docx" / "sample_docx.docx"])
documents = output["documents"]

View File

@ -19,14 +19,15 @@ class TestAzureOpenAIDocumentEmbedder:
assert embedder.meta_fields_to_embed == []
assert embedder.embedding_separator == "\n"
def test_to_dict(self):
component = AzureOpenAIDocumentEmbedder(
api_key="fake-api-key", azure_endpoint="https://example-resource.azure.openai.com/"
)
def test_to_dict(self, monkeypatch):
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key")
component = AzureOpenAIDocumentEmbedder(azure_endpoint="https://example-resource.azure.openai.com/")
data = component.to_dict()
assert data == {
"type": "haystack.components.embedders.azure_document_embedder.AzureOpenAIDocumentEmbedder",
"init_parameters": {
"api_key": {"env_vars": ["AZURE_OPENAI_API_KEY"], "strict": False, "type": "env_var"},
"azure_ad_token": {"env_vars": ["AZURE_OPENAI_AD_TOKEN"], "strict": False, "type": "env_var"},
"api_version": "2023-05-15",
"azure_deployment": "text-embedding-ada-002",
"azure_endpoint": "https://example-resource.azure.openai.com/",

View File

@ -16,14 +16,15 @@ class TestAzureOpenAITextEmbedder:
assert embedder.prefix == ""
assert embedder.suffix == ""
def test_to_dict(self):
component = AzureOpenAITextEmbedder(
api_key="fake-api-key", azure_endpoint="https://example-resource.azure.openai.com/"
)
def test_to_dict(self, monkeypatch):
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key")
component = AzureOpenAITextEmbedder(azure_endpoint="https://example-resource.azure.openai.com/")
data = component.to_dict()
assert data == {
"type": "haystack.components.embedders.azure_text_embedder.AzureOpenAITextEmbedder",
"init_parameters": {
"api_key": {"env_vars": ["AZURE_OPENAI_API_KEY"], "strict": False, "type": "env_var"},
"azure_ad_token": {"env_vars": ["AZURE_OPENAI_AD_TOKEN"], "strict": False, "type": "env_var"},
"azure_deployment": "text-embedding-ada-002",
"organization": None,
"azure_endpoint": "https://example-resource.azure.openai.com/",

View File

@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch
import numpy as np
import pytest
from huggingface_hub.utils import RepositoryNotFoundError
from haystack.utils.auth import Secret
from haystack.components.embedders.hugging_face_tei_document_embedder import HuggingFaceTEIDocumentEmbedder
from haystack.dataclasses import Document
@ -29,7 +30,7 @@ class TestHuggingFaceTEIDocumentEmbedder:
assert embedder.model == "BAAI/bge-small-en-v1.5"
assert embedder.url is None
assert embedder.token == "fake-api-token"
assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
assert embedder.prefix == ""
assert embedder.suffix == ""
assert embedder.batch_size == 32
@ -41,7 +42,7 @@ class TestHuggingFaceTEIDocumentEmbedder:
embedder = HuggingFaceTEIDocumentEmbedder(
model="sentence-transformers/all-mpnet-base-v2",
url="https://some_embedding_model.com",
token="fake-api-token",
token=Secret.from_token("fake-api-token"),
prefix="prefix",
suffix="suffix",
batch_size=64,
@ -52,7 +53,7 @@ class TestHuggingFaceTEIDocumentEmbedder:
assert embedder.model == "sentence-transformers/all-mpnet-base-v2"
assert embedder.url == "https://some_embedding_model.com"
assert embedder.token == "fake-api-token"
assert embedder.token == Secret.from_token("fake-api-token")
assert embedder.prefix == "prefix"
assert embedder.suffix == "suffix"
assert embedder.batch_size == 64
@ -78,6 +79,7 @@ class TestHuggingFaceTEIDocumentEmbedder:
"type": "haystack.components.embedders.hugging_face_tei_document_embedder.HuggingFaceTEIDocumentEmbedder",
"init_parameters": {
"model": "BAAI/bge-small-en-v1.5",
"token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
"url": None,
"prefix": "",
"suffix": "",
@ -92,7 +94,7 @@ class TestHuggingFaceTEIDocumentEmbedder:
component = HuggingFaceTEIDocumentEmbedder(
model="sentence-transformers/all-mpnet-base-v2",
url="https://some_embedding_model.com",
token="fake-api-token",
token=Secret.from_env_var("ENV_VAR", strict=False),
prefix="prefix",
suffix="suffix",
batch_size=64,
@ -106,6 +108,7 @@ class TestHuggingFaceTEIDocumentEmbedder:
assert data == {
"type": "haystack.components.embedders.hugging_face_tei_document_embedder.HuggingFaceTEIDocumentEmbedder",
"init_parameters": {
"token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
"model": "sentence-transformers/all-mpnet-base-v2",
"url": "https://some_embedding_model.com",
"prefix": "prefix",
@ -125,7 +128,7 @@ class TestHuggingFaceTEIDocumentEmbedder:
embedder = HuggingFaceTEIDocumentEmbedder(
model="sentence-transformers/all-mpnet-base-v2",
url="https://some_embedding_model.com",
token="fake-api-token",
token=Secret.from_token("fake-api-token"),
meta_fields_to_embed=["meta_field"],
embedding_separator=" | ",
)
@ -146,7 +149,7 @@ class TestHuggingFaceTEIDocumentEmbedder:
embedder = HuggingFaceTEIDocumentEmbedder(
model="sentence-transformers/all-mpnet-base-v2",
url="https://some_embedding_model.com",
token="fake-api-token",
token=Secret.from_token("fake-api-token"),
prefix="my_prefix ",
suffix=" my_suffix",
)
@ -168,7 +171,9 @@ class TestHuggingFaceTEIDocumentEmbedder:
mock_embedding_patch.side_effect = mock_embedding_generation
embedder = HuggingFaceTEIDocumentEmbedder(
model="BAAI/bge-small-en-v1.5", url="https://some_embedding_model.com", token="fake-api-token"
model="BAAI/bge-small-en-v1.5",
url="https://some_embedding_model.com",
token=Secret.from_token("fake-api-token"),
)
embeddings = embedder._embed_batch(texts_to_embed=texts, batch_size=2)
@ -192,7 +197,7 @@ class TestHuggingFaceTEIDocumentEmbedder:
embedder = HuggingFaceTEIDocumentEmbedder(
model="BAAI/bge-small-en-v1.5",
token="fake-api-token",
token=Secret.from_token("fake-api-token"),
prefix="prefix ",
suffix=" suffix",
meta_fields_to_embed=["topic"],
@ -228,7 +233,7 @@ class TestHuggingFaceTEIDocumentEmbedder:
embedder = HuggingFaceTEIDocumentEmbedder(
model="BAAI/bge-small-en-v1.5",
token="fake-api-token",
token=Secret.from_token("fake-api-token"),
prefix="prefix ",
suffix=" suffix",
meta_fields_to_embed=["topic"],
@ -252,7 +257,9 @@ class TestHuggingFaceTEIDocumentEmbedder:
def test_run_wrong_input_format(self, mock_check_valid_model):
embedder = HuggingFaceTEIDocumentEmbedder(
model="BAAI/bge-small-en-v1.5", url="https://some_embedding_model.com", token="fake-api-token"
model="BAAI/bge-small-en-v1.5",
url="https://some_embedding_model.com",
token=Secret.from_token("fake-api-token"),
)
# wrong formats
@ -267,7 +274,9 @@ class TestHuggingFaceTEIDocumentEmbedder:
def test_run_on_empty_list(self, mock_check_valid_model):
embedder = HuggingFaceTEIDocumentEmbedder(
model="BAAI/bge-small-en-v1.5", url="https://some_embedding_model.com", token="fake-api-token"
model="BAAI/bge-small-en-v1.5",
url="https://some_embedding_model.com",
token=Secret.from_token("fake-api-token"),
)
empty_list_input = []

View File

@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch
import numpy as np
import pytest
from huggingface_hub.utils import RepositoryNotFoundError
from haystack.utils.auth import Secret
from haystack.components.embedders.hugging_face_tei_text_embedder import HuggingFaceTEITextEmbedder
@ -27,7 +28,7 @@ class TestHuggingFaceTEITextEmbedder:
assert embedder.model == "BAAI/bge-small-en-v1.5"
assert embedder.url is None
assert embedder.token == "fake-api-token"
assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
assert embedder.prefix == ""
assert embedder.suffix == ""
@ -35,14 +36,14 @@ class TestHuggingFaceTEITextEmbedder:
embedder = HuggingFaceTEITextEmbedder(
model="sentence-transformers/all-mpnet-base-v2",
url="https://some_embedding_model.com",
token="fake-api-token",
token=Secret.from_token("fake-api-token"),
prefix="prefix",
suffix="suffix",
)
assert embedder.model == "sentence-transformers/all-mpnet-base-v2"
assert embedder.url == "https://some_embedding_model.com"
assert embedder.token == "fake-api-token"
assert embedder.token == Secret.from_token("fake-api-token")
assert embedder.prefix == "prefix"
assert embedder.suffix == "suffix"
@ -62,14 +63,20 @@ class TestHuggingFaceTEITextEmbedder:
assert data == {
"type": "haystack.components.embedders.hugging_face_tei_text_embedder.HuggingFaceTEITextEmbedder",
"init_parameters": {"model": "BAAI/bge-small-en-v1.5", "url": None, "prefix": "", "suffix": ""},
"init_parameters": {
"token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
"model": "BAAI/bge-small-en-v1.5",
"url": None,
"prefix": "",
"suffix": "",
},
}
def test_to_dict_with_custom_init_parameters(self, mock_check_valid_model):
component = HuggingFaceTEITextEmbedder(
model="sentence-transformers/all-mpnet-base-v2",
url="https://some_embedding_model.com",
token="fake-api-token",
token=Secret.from_env_var("ENV_VAR", strict=False),
prefix="prefix",
suffix="suffix",
)
@ -79,6 +86,7 @@ class TestHuggingFaceTEITextEmbedder:
assert data == {
"type": "haystack.components.embedders.hugging_face_tei_text_embedder.HuggingFaceTEITextEmbedder",
"init_parameters": {
"token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
"model": "sentence-transformers/all-mpnet-base-v2",
"url": "https://some_embedding_model.com",
"prefix": "prefix",
@ -91,7 +99,10 @@ class TestHuggingFaceTEITextEmbedder:
mock_embedding_patch.side_effect = mock_embedding_generation
embedder = HuggingFaceTEITextEmbedder(
model="BAAI/bge-small-en-v1.5", token="fake-api-token", prefix="prefix ", suffix=" suffix"
model="BAAI/bge-small-en-v1.5",
token=Secret.from_token("fake-api-token"),
prefix="prefix ",
suffix=" suffix",
)
result = embedder.run(text="The food was delicious")
@ -103,7 +114,9 @@ class TestHuggingFaceTEITextEmbedder:
def test_run_wrong_input_format(self, mock_check_valid_model):
embedder = HuggingFaceTEITextEmbedder(
model="BAAI/bge-small-en-v1.5", url="https://some_embedding_model.com", token="fake-api-token"
model="BAAI/bge-small-en-v1.5",
url="https://some_embedding_model.com",
token=Secret.from_token("fake-api-token"),
)
list_integers_input = [1, 2, 3]

View File

@ -1,9 +1,9 @@
import os
from typing import List
from haystack.utils.auth import Secret
import numpy as np
import pytest
from openai import OpenAIError
from haystack import Document
from haystack.components.embedders.openai_document_embedder import OpenAIDocumentEmbedder
@ -37,7 +37,7 @@ class TestOpenAIDocumentEmbedder:
def test_init_with_parameters(self):
embedder = OpenAIDocumentEmbedder(
api_key="fake-api-key",
api_key=Secret.from_token("fake-api-key"),
model="model",
organization="my-org",
prefix="prefix",
@ -58,15 +58,17 @@ class TestOpenAIDocumentEmbedder:
def test_init_fail_wo_api_key(self, monkeypatch):
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
with pytest.raises(OpenAIError):
with pytest.raises(ValueError, match="None of the .* environment variables are set"):
OpenAIDocumentEmbedder()
def test_to_dict(self):
component = OpenAIDocumentEmbedder(api_key="fake-api-key")
def test_to_dict(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key")
component = OpenAIDocumentEmbedder()
data = component.to_dict()
assert data == {
"type": "haystack.components.embedders.openai_document_embedder.OpenAIDocumentEmbedder",
"init_parameters": {
"api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
"api_base_url": None,
"model": "text-embedding-ada-002",
"organization": None,
@ -79,9 +81,10 @@ class TestOpenAIDocumentEmbedder:
},
}
def test_to_dict_with_custom_init_parameters(self):
def test_to_dict_with_custom_init_parameters(self, monkeypatch):
monkeypatch.setenv("ENV_VAR", "fake-api-key")
component = OpenAIDocumentEmbedder(
api_key="fake-api-key",
api_key=Secret.from_env_var("ENV_VAR", strict=False),
model="model",
organization="my-org",
prefix="prefix",
@ -95,6 +98,7 @@ class TestOpenAIDocumentEmbedder:
assert data == {
"type": "haystack.components.embedders.openai_document_embedder.OpenAIDocumentEmbedder",
"init_parameters": {
"api_key": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
"api_base_url": None,
"model": "model",
"organization": "my-org",
@ -113,7 +117,7 @@ class TestOpenAIDocumentEmbedder:
]
embedder = OpenAIDocumentEmbedder(
api_key="fake-api-key", meta_fields_to_embed=["meta_field"], embedding_separator=" | "
api_key=Secret.from_token("fake-api-key"), meta_fields_to_embed=["meta_field"], embedding_separator=" | "
)
prepared_texts = embedder._prepare_texts_to_embed(documents)
@ -130,7 +134,9 @@ class TestOpenAIDocumentEmbedder:
def test_prepare_texts_to_embed_w_suffix(self):
documents = [Document(content=f"document number {i}") for i in range(5)]
embedder = OpenAIDocumentEmbedder(api_key="fake-api-key", prefix="my_prefix ", suffix=" my_suffix")
embedder = OpenAIDocumentEmbedder(
api_key=Secret.from_token("fake-api-key"), prefix="my_prefix ", suffix=" my_suffix"
)
prepared_texts = embedder._prepare_texts_to_embed(documents)
@ -143,7 +149,7 @@ class TestOpenAIDocumentEmbedder:
]
def test_run_wrong_input_format(self):
embedder = OpenAIDocumentEmbedder(api_key="fake-api-key")
embedder = OpenAIDocumentEmbedder(api_key=Secret.from_token("fake-api-key"))
# wrong formats
string_input = "text"
@ -156,7 +162,7 @@ class TestOpenAIDocumentEmbedder:
embedder.run(documents=list_integers_input)
def test_run_on_empty_list(self):
embedder = OpenAIDocumentEmbedder(api_key="fake-api-key")
embedder = OpenAIDocumentEmbedder(api_key=Secret.from_token("fake-api-key"))
empty_list_input = []
result = embedder.run(documents=empty_list_input)

View File

@ -1,7 +1,7 @@
import os
from haystack.utils.auth import Secret
import pytest
from openai import OpenAIError
from haystack.components.embedders.openai_text_embedder import OpenAITextEmbedder
@ -19,7 +19,11 @@ class TestOpenAITextEmbedder:
def test_init_with_parameters(self):
embedder = OpenAITextEmbedder(
api_key="fake-api-key", model="model", organization="fake-organization", prefix="prefix", suffix="suffix"
api_key=Secret.from_token("fake-api-key"),
model="model",
organization="fake-organization",
prefix="prefix",
suffix="suffix",
)
assert embedder.client.api_key == "fake-api-key"
assert embedder.model == "model"
@ -29,25 +33,38 @@ class TestOpenAITextEmbedder:
def test_init_fail_wo_api_key(self, monkeypatch):
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
with pytest.raises(OpenAIError):
with pytest.raises(ValueError, match="None of the .* environment variables are set"):
OpenAITextEmbedder()
def test_to_dict(self):
component = OpenAITextEmbedder(api_key="fake-api-key")
def test_to_dict(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key")
component = OpenAITextEmbedder()
data = component.to_dict()
assert data == {
"type": "haystack.components.embedders.openai_text_embedder.OpenAITextEmbedder",
"init_parameters": {"model": "text-embedding-ada-002", "organization": None, "prefix": "", "suffix": ""},
"init_parameters": {
"api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
"model": "text-embedding-ada-002",
"organization": None,
"prefix": "",
"suffix": "",
},
}
def test_to_dict_with_custom_init_parameters(self):
def test_to_dict_with_custom_init_parameters(self, monkeypatch):
monkeypatch.setenv("ENV_VAR", "fake-api-key")
component = OpenAITextEmbedder(
api_key="fake-api-key", model="model", organization="fake-organization", prefix="prefix", suffix="suffix"
api_key=Secret.from_env_var("ENV_VAR", strict=False),
model="model",
organization="fake-organization",
prefix="prefix",
suffix="suffix",
)
data = component.to_dict()
assert data == {
"type": "haystack.components.embedders.openai_text_embedder.OpenAITextEmbedder",
"init_parameters": {
"api_key": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
"model": "model",
"organization": "fake-organization",
"prefix": "prefix",
@ -56,7 +73,7 @@ class TestOpenAITextEmbedder:
}
def test_run_wrong_input_format(self):
embedder = OpenAITextEmbedder(api_key="fake-api-key")
embedder = OpenAITextEmbedder(api_key=Secret.from_token("fake-api-key"))
list_integers_input = [1, 2, 3]

View File

@ -1,6 +1,7 @@
from unittest.mock import patch, MagicMock
import pytest
import numpy as np
from haystack.utils.auth import Secret
from haystack import Document
from haystack.components.embedders.sentence_transformers_document_embedder import SentenceTransformersDocumentEmbedder
@ -11,7 +12,7 @@ class TestSentenceTransformersDocumentEmbedder:
embedder = SentenceTransformersDocumentEmbedder(model="model")
assert embedder.model == "model"
assert embedder.device == "cpu"
assert embedder.token is None
assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
assert embedder.prefix == ""
assert embedder.suffix == ""
assert embedder.batch_size == 32
@ -24,7 +25,7 @@ class TestSentenceTransformersDocumentEmbedder:
embedder = SentenceTransformersDocumentEmbedder(
model="model",
device="cuda",
token=True,
token=Secret.from_token("fake-api-token"),
prefix="prefix",
suffix="suffix",
batch_size=64,
@ -35,7 +36,7 @@ class TestSentenceTransformersDocumentEmbedder:
)
assert embedder.model == "model"
assert embedder.device == "cuda"
assert embedder.token is True
assert embedder.token == Secret.from_token("fake-api-token")
assert embedder.prefix == "prefix"
assert embedder.suffix == "suffix"
assert embedder.batch_size == 64
@ -52,7 +53,7 @@ class TestSentenceTransformersDocumentEmbedder:
"init_parameters": {
"model": "model",
"device": "cpu",
"token": None,
"token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
"prefix": "",
"suffix": "",
"batch_size": 32,
@ -67,7 +68,7 @@ class TestSentenceTransformersDocumentEmbedder:
component = SentenceTransformersDocumentEmbedder(
model="model",
device="cuda",
token="the-token",
token=Secret.from_env_var("ENV_VAR", strict=False),
prefix="prefix",
suffix="suffix",
batch_size=64,
@ -83,7 +84,7 @@ class TestSentenceTransformersDocumentEmbedder:
"init_parameters": {
"model": "model",
"device": "cuda",
"token": None, # the token is not serialized
"token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
"prefix": "prefix",
"suffix": "suffix",
"batch_size": 64,
@ -98,10 +99,10 @@ class TestSentenceTransformersDocumentEmbedder:
"haystack.components.embedders.sentence_transformers_document_embedder._SentenceTransformersEmbeddingBackendFactory"
)
def test_warmup(self, mocked_factory):
embedder = SentenceTransformersDocumentEmbedder(model="model")
embedder = SentenceTransformersDocumentEmbedder(model="model", token=None)
mocked_factory.get_embedding_backend.assert_not_called()
embedder.warm_up()
mocked_factory.get_embedding_backend.assert_called_once_with(model="model", device="cpu", use_auth_token=None)
mocked_factory.get_embedding_backend.assert_called_once_with(model="model", device="cpu", auth_token=None)
@patch(
"haystack.components.embedders.sentence_transformers_document_embedder._SentenceTransformersEmbeddingBackendFactory"

View File

@ -3,6 +3,7 @@ import pytest
from haystack.components.embedders.backends.sentence_transformers_backend import (
_SentenceTransformersEmbeddingBackendFactory,
)
from haystack.utils.auth import Secret
@patch("haystack.components.embedders.backends.sentence_transformers_backend.SentenceTransformer")
@ -22,10 +23,10 @@ def test_factory_behavior(mock_sentence_transformer):
@patch("haystack.components.embedders.backends.sentence_transformers_backend.SentenceTransformer")
def test_model_initialization(mock_sentence_transformer):
_SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(
model="model", device="cpu", use_auth_token="my_token"
model="model", device="cpu", auth_token=Secret.from_token("fake-api-token")
)
mock_sentence_transformer.assert_called_once_with(
model_name_or_path="model", device="cpu", use_auth_token="my_token"
model_name_or_path="model", device="cpu", use_auth_token="fake-api-token"
)

View File

@ -1,5 +1,6 @@
from unittest.mock import patch, MagicMock
import pytest
from haystack.utils.auth import Secret
import numpy as np
@ -11,7 +12,7 @@ class TestSentenceTransformersTextEmbedder:
embedder = SentenceTransformersTextEmbedder(model="model")
assert embedder.model == "model"
assert embedder.device == "cpu"
assert embedder.token is None
assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
assert embedder.prefix == ""
assert embedder.suffix == ""
assert embedder.batch_size == 32
@ -22,7 +23,7 @@ class TestSentenceTransformersTextEmbedder:
embedder = SentenceTransformersTextEmbedder(
model="model",
device="cuda",
token=True,
token=Secret.from_token("fake-api-token"),
prefix="prefix",
suffix="suffix",
batch_size=64,
@ -31,7 +32,7 @@ class TestSentenceTransformersTextEmbedder:
)
assert embedder.model == "model"
assert embedder.device == "cuda"
assert embedder.token is True
assert embedder.token == Secret.from_token("fake-api-token")
assert embedder.prefix == "prefix"
assert embedder.suffix == "suffix"
assert embedder.batch_size == 64
@ -44,9 +45,9 @@ class TestSentenceTransformersTextEmbedder:
assert data == {
"type": "haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder",
"init_parameters": {
"token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
"model": "model",
"device": "cpu",
"token": None,
"prefix": "",
"suffix": "",
"batch_size": 32,
@ -59,7 +60,7 @@ class TestSentenceTransformersTextEmbedder:
component = SentenceTransformersTextEmbedder(
model="model",
device="cuda",
token=True,
token=Secret.from_env_var("ENV_VAR", strict=False),
prefix="prefix",
suffix="suffix",
batch_size=64,
@ -70,9 +71,9 @@ class TestSentenceTransformersTextEmbedder:
assert data == {
"type": "haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder",
"init_parameters": {
"token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
"model": "model",
"device": "cuda",
"token": True,
"prefix": "prefix",
"suffix": "suffix",
"batch_size": 64,
@ -82,30 +83,18 @@ class TestSentenceTransformersTextEmbedder:
}
def test_to_dict_not_serialize_token(self):
component = SentenceTransformersTextEmbedder(model="model", token="awesome-token")
data = component.to_dict()
assert data == {
"type": "haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder",
"init_parameters": {
"model": "model",
"device": "cpu",
"token": None,
"prefix": "",
"suffix": "",
"batch_size": 32,
"progress_bar": True,
"normalize_embeddings": False,
},
}
component = SentenceTransformersTextEmbedder(model="model", token=Secret.from_token("fake-api-token"))
with pytest.raises(ValueError, match="Cannot serialize token-based secret"):
component.to_dict()
@patch(
"haystack.components.embedders.sentence_transformers_text_embedder._SentenceTransformersEmbeddingBackendFactory"
)
def test_warmup(self, mocked_factory):
embedder = SentenceTransformersTextEmbedder(model="model")
embedder = SentenceTransformersTextEmbedder(model="model", token=None)
mocked_factory.get_embedding_backend.assert_not_called()
embedder.warm_up()
mocked_factory.get_embedding_backend.assert_called_once_with(model="model", device="cpu", use_auth_token=None)
mocked_factory.get_embedding_backend.assert_called_once_with(model="model", device="cpu", auth_token=None)
@patch(
"haystack.components.embedders.sentence_transformers_text_embedder._SentenceTransformersEmbeddingBackendFactory"

View File

@ -6,11 +6,13 @@ from openai import OpenAIError
from haystack.components.generators.chat import AzureOpenAIChatGenerator
from haystack.components.generators.utils import print_streaming_chunk
from haystack.dataclasses import ChatMessage
from haystack.utils.auth import Secret
class TestOpenAIChatGenerator:
def test_init_default(self):
component = AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint", api_key="test-api-key")
def test_init_default(self, monkeypatch):
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key")
component = AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint")
assert component.client.api_key == "test-api-key"
assert component.azure_deployment == "gpt-35-turbo"
assert component.streaming_callback is None
@ -18,13 +20,14 @@ class TestOpenAIChatGenerator:
def test_init_fail_wo_api_key(self, monkeypatch):
monkeypatch.delenv("AZURE_OPENAI_API_KEY", raising=False)
monkeypatch.delenv("AZURE_OPENAI_AD_TOKEN", raising=False)
with pytest.raises(OpenAIError):
AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint")
def test_init_with_parameters(self):
component = AzureOpenAIChatGenerator(
api_key=Secret.from_token("test-api-key"),
azure_endpoint="some-non-existing-endpoint",
api_key="test-api-key",
streaming_callback=print_streaming_chunk,
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
)
@ -33,12 +36,15 @@ class TestOpenAIChatGenerator:
assert component.streaming_callback is print_streaming_chunk
assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
def test_to_dict_default(self):
component = AzureOpenAIChatGenerator(api_key="test-api-key", azure_endpoint="some-non-existing-endpoint")
def test_to_dict_default(self, monkeypatch):
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key")
component = AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint")
data = component.to_dict()
assert data == {
"type": "haystack.components.generators.chat.azure.AzureOpenAIChatGenerator",
"init_parameters": {
"api_key": {"env_vars": ["AZURE_OPENAI_API_KEY"], "strict": False, "type": "env_var"},
"azure_ad_token": {"env_vars": ["AZURE_OPENAI_AD_TOKEN"], "strict": False, "type": "env_var"},
"api_version": "2023-05-15",
"azure_endpoint": "some-non-existing-endpoint",
"azure_deployment": "gpt-35-turbo",
@ -48,9 +54,11 @@ class TestOpenAIChatGenerator:
},
}
def test_to_dict_with_parameters(self):
def test_to_dict_with_parameters(self, monkeypatch):
monkeypatch.setenv("ENV_VAR", "test-api-key")
component = AzureOpenAIChatGenerator(
api_key="test-api-key",
api_key=Secret.from_env_var("ENV_VAR", strict=False),
azure_ad_token=Secret.from_env_var("ENV_VAR1", strict=False),
azure_endpoint="some-non-existing-endpoint",
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
)
@ -58,6 +66,8 @@ class TestOpenAIChatGenerator:
assert data == {
"type": "haystack.components.generators.chat.azure.AzureOpenAIChatGenerator",
"init_parameters": {
"api_key": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
"azure_ad_token": {"env_vars": ["ENV_VAR1"], "strict": False, "type": "env_var"},
"api_version": "2023-05-15",
"azure_endpoint": "some-non-existing-endpoint",
"azure_deployment": "gpt-35-turbo",

View File

@ -1,4 +1,5 @@
from unittest.mock import patch, Mock
from haystack.utils.auth import Secret
import pytest
from transformers import PreTrainedTokenizer
@ -56,7 +57,7 @@ class TestHuggingFaceLocalChatGenerator:
generator = HuggingFaceLocalChatGenerator(
model="mistralai/Mistral-7B-Instruct-v0.2",
task="text2text-generation",
token="test-token",
token=Secret.from_token("test-token"),
device=ComponentDevice.from_str("cpu"),
)
@ -72,6 +73,7 @@ class TestHuggingFaceLocalChatGenerator:
model="mistralai/Mistral-7B-Instruct-v0.2",
task="text2text-generation",
device=ComponentDevice.from_str("cpu"),
token=None,
)
assert generator.huggingface_pipeline_kwargs == {
@ -82,7 +84,9 @@ class TestHuggingFaceLocalChatGenerator:
}
def test_init_task_parameter(self):
generator = HuggingFaceLocalChatGenerator(task="text2text-generation", device=ComponentDevice.from_str("cpu"))
generator = HuggingFaceLocalChatGenerator(
task="text2text-generation", device=ComponentDevice.from_str("cpu"), token=None
)
assert generator.huggingface_pipeline_kwargs == {
"model": "HuggingFaceH4/zephyr-7b-beta",
@ -93,7 +97,9 @@ class TestHuggingFaceLocalChatGenerator:
def test_init_task_in_huggingface_pipeline_kwargs(self):
generator = HuggingFaceLocalChatGenerator(
huggingface_pipeline_kwargs={"task": "text2text-generation"}, device=ComponentDevice.from_str("cpu")
huggingface_pipeline_kwargs={"task": "text2text-generation"},
device=ComponentDevice.from_str("cpu"),
token=None,
)
assert generator.huggingface_pipeline_kwargs == {
@ -105,7 +111,7 @@ class TestHuggingFaceLocalChatGenerator:
def test_init_task_inferred_from_model_name(self, model_info_mock):
generator = HuggingFaceLocalChatGenerator(
model="mistralai/Mistral-7B-Instruct-v0.2", device=ComponentDevice.from_str("cpu")
model="mistralai/Mistral-7B-Instruct-v0.2", device=ComponentDevice.from_str("cpu"), token=None
)
assert generator.huggingface_pipeline_kwargs == {
@ -122,7 +128,7 @@ class TestHuggingFaceLocalChatGenerator:
def test_to_dict(self, model_info_mock):
generator = HuggingFaceLocalChatGenerator(
model="NousResearch/Llama-2-7b-chat-hf",
token="token",
token=Secret.from_env_var("ENV_VAR", strict=False),
generation_kwargs={"n": 5},
stop_words=["stop", "words"],
streaming_callback=lambda x: x,
@ -133,6 +139,7 @@ class TestHuggingFaceLocalChatGenerator:
init_params = result["init_parameters"]
# Assert that the init_params dictionary contains the expected keys and values
assert init_params["token"] == {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}
assert init_params["huggingface_pipeline_kwargs"]["model"] == "NousResearch/Llama-2-7b-chat-hf"
assert "token" not in init_params["huggingface_pipeline_kwargs"]
assert init_params["generation_kwargs"] == {"max_new_tokens": 512, "n": 5, "stop_sequences": ["stop", "words"]}

View File

@ -1,4 +1,5 @@
from unittest.mock import patch, MagicMock, Mock
from haystack.utils.auth import Secret
import pytest
from huggingface_hub.inference._text_generation import TextGenerationStreamResponse, Token, StreamDetails, FinishReason
@ -59,7 +60,7 @@ class TestHuggingFaceTGIChatGenerator:
# Initialize the HuggingFaceTGIChatGenerator object with valid parameters
generator = HuggingFaceTGIChatGenerator(
model="NousResearch/Llama-2-7b-chat-hf",
token="token",
token=Secret.from_env_var("ENV_VAR", strict=False),
generation_kwargs={"n": 5},
stop_words=["stop", "words"],
streaming_callback=lambda x: x,
@ -71,7 +72,7 @@ class TestHuggingFaceTGIChatGenerator:
# Assert that the init_params dictionary contains the expected keys and values
assert init_params["model"] == "NousResearch/Llama-2-7b-chat-hf"
assert init_params["token"] is None
assert init_params["token"] == {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}
assert init_params["generation_kwargs"] == {"n": 5, "stop_sequences": ["stop", "words"]}
def test_from_dict(self, mock_check_valid_model):

View File

@ -2,6 +2,7 @@ import os
import pytest
from openai import OpenAIError
from haystack.utils.auth import Secret
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.components.generators.utils import print_streaming_chunk
@ -17,8 +18,9 @@ def chat_messages():
class TestOpenAIChatGenerator:
def test_init_default(self):
component = OpenAIChatGenerator(api_key="test-api-key")
def test_init_default(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
component = OpenAIChatGenerator()
assert component.client.api_key == "test-api-key"
assert component.model == "gpt-3.5-turbo"
assert component.streaming_callback is None
@ -26,12 +28,12 @@ class TestOpenAIChatGenerator:
def test_init_fail_wo_api_key(self, monkeypatch):
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
with pytest.raises(OpenAIError):
with pytest.raises(ValueError, match="None of the .* environment variables are set"):
OpenAIChatGenerator()
def test_init_with_parameters(self):
component = OpenAIChatGenerator(
api_key="test-api-key",
api_key=Secret.from_token("test-api-key"),
model="gpt-4",
streaming_callback=print_streaming_chunk,
api_base_url="test-base-url",
@ -42,12 +44,14 @@ class TestOpenAIChatGenerator:
assert component.streaming_callback is print_streaming_chunk
assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
def test_to_dict_default(self):
component = OpenAIChatGenerator(api_key="test-api-key")
def test_to_dict_default(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
component = OpenAIChatGenerator()
data = component.to_dict()
assert data == {
"type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
"init_parameters": {
"api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
"model": "gpt-3.5-turbo",
"organization": None,
"streaming_callback": None,
@ -56,9 +60,10 @@ class TestOpenAIChatGenerator:
},
}
def test_to_dict_with_parameters(self):
def test_to_dict_with_parameters(self, monkeypatch):
monkeypatch.setenv("ENV_VAR", "test-api-key")
component = OpenAIChatGenerator(
api_key="test-api-key",
api_key=Secret.from_env_var("ENV_VAR"),
model="gpt-4",
streaming_callback=print_streaming_chunk,
api_base_url="test-base-url",
@ -68,6 +73,7 @@ class TestOpenAIChatGenerator:
assert data == {
"type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
"init_parameters": {
"api_key": {"env_vars": ["ENV_VAR"], "strict": True, "type": "env_var"},
"model": "gpt-4",
"organization": None,
"api_base_url": "test-base-url",
@ -76,9 +82,9 @@ class TestOpenAIChatGenerator:
},
}
def test_to_dict_with_lambda_streaming_callback(self):
def test_to_dict_with_lambda_streaming_callback(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
component = OpenAIChatGenerator(
api_key="test-api-key",
model="gpt-4",
streaming_callback=lambda x: x,
api_base_url="test-base-url",
@ -88,6 +94,7 @@ class TestOpenAIChatGenerator:
assert data == {
"type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
"init_parameters": {
"api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
"model": "gpt-4",
"organization": None,
"api_base_url": "test-base-url",
@ -96,10 +103,12 @@ class TestOpenAIChatGenerator:
},
}
def test_from_dict(self):
def test_from_dict(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key")
data = {
"type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
"init_parameters": {
"api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
"model": "gpt-4",
"api_base_url": "test-base-url",
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
@ -111,12 +120,14 @@ class TestOpenAIChatGenerator:
assert component.streaming_callback is print_streaming_chunk
assert component.api_base_url == "test-base-url"
assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
assert component.api_key == Secret.from_env_var("OPENAI_API_KEY")
def test_from_dict_fail_wo_env_var(self, monkeypatch):
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
data = {
"type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
"init_parameters": {
"api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
"model": "gpt-4",
"organization": None,
"api_base_url": "test-base-url",
@ -124,11 +135,11 @@ class TestOpenAIChatGenerator:
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
},
}
with pytest.raises(OpenAIError):
with pytest.raises(ValueError, match="None of the .* environment variables are set"):
OpenAIChatGenerator.from_dict(data)
def test_run(self, chat_messages, mock_chat_completion):
component = OpenAIChatGenerator()
component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"))
response = component.run(chat_messages)
# check that the component returns the correct ChatMessage response
@ -139,7 +150,9 @@ class TestOpenAIChatGenerator:
assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
def test_run_with_params(self, chat_messages, mock_chat_completion):
component = OpenAIChatGenerator(generation_kwargs={"max_tokens": 10, "temperature": 0.5})
component = OpenAIChatGenerator(
api_key=Secret.from_token("test-api-key"), generation_kwargs={"max_tokens": 10, "temperature": 0.5}
)
response = component.run(chat_messages)
# check that the component calls the OpenAI API with the correct parameters
@ -161,7 +174,9 @@ class TestOpenAIChatGenerator:
nonlocal streaming_callback_called
streaming_callback_called = True
component = OpenAIChatGenerator(streaming_callback=streaming_callback)
component = OpenAIChatGenerator(
api_key=Secret.from_token("test-api-key"), streaming_callback=streaming_callback
)
response = component.run(chat_messages)
# check we called the streaming callback
@ -176,7 +191,7 @@ class TestOpenAIChatGenerator:
assert "Hello" in response["replies"][0].content # see mock_chat_completion_chunk
def test_check_abnormal_completions(self, caplog):
component = OpenAIChatGenerator(api_key="test-api-key")
component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"))
messages = [
ChatMessage.from_assistant(
"", meta={"finish_reason": "content_filter" if i % 2 == 0 else "length", "index": i}
@ -208,7 +223,7 @@ class TestOpenAIChatGenerator:
@pytest.mark.integration
def test_live_run(self):
chat_messages = [ChatMessage.from_user("What's the capital of France")]
component = OpenAIChatGenerator(api_key=os.environ.get("OPENAI_API_KEY"), generation_kwargs={"n": 1})
component = OpenAIChatGenerator(generation_kwargs={"n": 1})
results = component.run(chat_messages)
assert len(results["replies"]) == 1
message: ChatMessage = results["replies"][0]
@ -222,7 +237,7 @@ class TestOpenAIChatGenerator:
)
@pytest.mark.integration
def test_live_run_wrong_model(self, chat_messages):
component = OpenAIChatGenerator(model="something-obviously-wrong", api_key=os.environ.get("OPENAI_API_KEY"))
component = OpenAIChatGenerator(model="something-obviously-wrong")
with pytest.raises(OpenAIError):
component.run(chat_messages)
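The error the updated failure-path tests match on comes from strict env-var resolution, not from the OpenAI client; a short sketch of the two secret flavors these tests lean on (the unset variable name below is hypothetical):

import pytest

from haystack.utils import Secret

# Token secrets resolve to the wrapped string; handy for run-path unit tests.
assert Secret.from_token("test-api-key").resolve_value() == "test-api-key"

# Strict env-var secrets (the default) raise on resolution when no listed
# variable is set, which is the ValueError the pytest.raises(..., match=...)
# assertions target.
with pytest.raises(ValueError, match="None of the .* environment variables are set"):
    Secret.from_env_var("SOME_UNSET_ENV_VAR").resolve_value()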

View File

@ -1,4 +1,5 @@
import os
from haystack.utils.auth import Secret
import pytest
from openai import OpenAIError
@ -8,8 +9,9 @@ from haystack.components.generators.utils import print_streaming_chunk
class TestAzureOpenAIGenerator:
def test_init_default(self):
component = AzureOpenAIGenerator(api_key="test-api-key", azure_endpoint="some-non-existing-endpoint")
def test_init_default(self, monkeypatch):
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key")
component = AzureOpenAIGenerator(azure_endpoint="some-non-existing-endpoint")
assert component.client.api_key == "test-api-key"
assert component.azure_deployment == "gpt-35-turbo"
assert component.streaming_callback is None
@ -17,28 +19,32 @@ class TestAzureOpenAIGenerator:
def test_init_fail_wo_api_key(self, monkeypatch):
monkeypatch.delenv("AZURE_OPENAI_API_KEY", raising=False)
monkeypatch.delenv("AZURE_OPENAI_AD_TOKEN", raising=False)
with pytest.raises(OpenAIError):
AzureOpenAIGenerator(azure_endpoint="some-non-existing-endpoint")
def test_init_with_parameters(self):
component = AzureOpenAIGenerator(
api_key="test-api-key",
api_key=Secret.from_token("fake-api-key"),
azure_endpoint="some-non-existing-endpoint",
azure_deployment="gpt-35-turbo",
streaming_callback=print_streaming_chunk,
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
)
assert component.client.api_key == "test-api-key"
assert component.client.api_key == "fake-api-key"
assert component.azure_deployment == "gpt-35-turbo"
assert component.streaming_callback is print_streaming_chunk
assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
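Both Azure credentials default to non-strict env-var secrets because either one on its own may legitimately be absent; a minimal sketch of the fallback this enables, assuming neither variable is set in the environment:

from haystack.utils import Secret

api_key = Secret.from_env_var("AZURE_OPENAI_API_KEY", strict=False)
ad_token = Secret.from_env_var("AZURE_OPENAI_AD_TOKEN", strict=False)
# Non-strict secrets resolve to None instead of raising, so the component
# can try the API key first and fall back to the AD token.
assert api_key.resolve_value() is None
assert ad_token.resolve_value() is None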
def test_to_dict_default(self):
component = AzureOpenAIGenerator(api_key="test-api-key", azure_endpoint="some-non-existing-endpoint")
def test_to_dict_default(self, monkeypatch):
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key")
component = AzureOpenAIGenerator(azure_endpoint="some-non-existing-endpoint")
data = component.to_dict()
assert data == {
"type": "haystack.components.generators.azure.AzureOpenAIGenerator",
"init_parameters": {
"api_key": {"env_vars": ["AZURE_OPENAI_API_KEY"], "strict": False, "type": "env_var"},
"azure_ad_token": {"env_vars": ["AZURE_OPENAI_AD_TOKEN"], "strict": False, "type": "env_var"},
"azure_deployment": "gpt-35-turbo",
"api_version": "2023-05-15",
"streaming_callback": None,
@ -49,9 +55,11 @@ class TestAzureOpenAIGenerator:
},
}
def test_to_dict_with_parameters(self):
def test_to_dict_with_parameters(self, monkeypatch):
monkeypatch.setenv("ENV_VAR", "test-api-key")
component = AzureOpenAIGenerator(
api_key="test-api-key",
api_key=Secret.from_env_var("ENV_VAR", strict=False),
azure_ad_token=Secret.from_env_var("ENV_VAR1", strict=False),
azure_endpoint="some-non-existing-endpoint",
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
)
@ -60,6 +68,8 @@ class TestAzureOpenAIGenerator:
assert data == {
"type": "haystack.components.generators.azure.AzureOpenAIGenerator",
"init_parameters": {
"api_key": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
"azure_ad_token": {"env_vars": ["ENV_VAR1"], "strict": False, "type": "env_var"},
"azure_deployment": "gpt-35-turbo",
"api_version": "2023-05-15",
"streaming_callback": None,

View File

@ -4,14 +4,16 @@ from unittest.mock import Mock, patch
import pytest
import torch
from transformers import PreTrainedTokenizerFast
from haystack.utils.auth import Secret
from haystack.components.generators.hugging_face_local import HuggingFaceLocalGenerator, StopWordsCriteria
from haystack.utils import ComponentDevice, Device
from haystack.utils import ComponentDevice
class TestHuggingFaceLocalGenerator:
@patch("haystack.components.generators.hugging_face_local.model_info")
def test_init_default(self, model_info_mock):
def test_init_default(self, model_info_mock, monkeypatch):
monkeypatch.delenv("HF_API_TOKEN", raising=False)
model_info_mock.return_value.pipeline_tag = "text2text-generation"
generator = HuggingFaceLocalGenerator()
@ -26,30 +28,33 @@ class TestHuggingFaceLocalGenerator:
def test_init_custom_token(self):
generator = HuggingFaceLocalGenerator(
model="google/flan-t5-base", task="text2text-generation", token="test-token"
model="google/flan-t5-base", task="text2text-generation", token=Secret.from_token("fake-api-token")
)
assert generator.huggingface_pipeline_kwargs == {
"model": "google/flan-t5-base",
"task": "text2text-generation",
"token": "test-token",
"token": "fake-api-token",
"device": ComponentDevice.resolve_device(None).to_hf(),
}
def test_init_custom_device(self):
generator = HuggingFaceLocalGenerator(
model="google/flan-t5-base", task="text2text-generation", device=ComponentDevice.from_str("cuda:0")
model="google/flan-t5-base",
task="text2text-generation",
device=ComponentDevice.from_str("cuda:0"),
token=Secret.from_token("fake-api-token"),
)
assert generator.huggingface_pipeline_kwargs == {
"model": "google/flan-t5-base",
"task": "text2text-generation",
"token": None,
"token": "fake-api-token",
"device": "cuda:0",
}
def test_init_task_parameter(self):
generator = HuggingFaceLocalGenerator(task="text2text-generation")
generator = HuggingFaceLocalGenerator(task="text2text-generation", token=None)
assert generator.huggingface_pipeline_kwargs == {
"model": "google/flan-t5-base",
@ -59,7 +64,7 @@ class TestHuggingFaceLocalGenerator:
}
def test_init_task_in_huggingface_pipeline_kwargs(self):
generator = HuggingFaceLocalGenerator(huggingface_pipeline_kwargs={"task": "text2text-generation"})
generator = HuggingFaceLocalGenerator(huggingface_pipeline_kwargs={"task": "text2text-generation"}, token=None)
assert generator.huggingface_pipeline_kwargs == {
"model": "google/flan-t5-base",
@ -71,7 +76,7 @@ class TestHuggingFaceLocalGenerator:
@patch("haystack.components.generators.hugging_face_local.model_info")
def test_init_task_inferred_from_model_name(self, model_info_mock):
model_info_mock.return_value.pipeline_tag = "text2text-generation"
generator = HuggingFaceLocalGenerator(model="google/flan-t5-base")
generator = HuggingFaceLocalGenerator(model="google/flan-t5-base", token=None)
assert generator.huggingface_pipeline_kwargs == {
"model": "google/flan-t5-base",
@ -101,7 +106,7 @@ class TestHuggingFaceLocalGenerator:
model="google/flan-t5-base",
task="text2text-generation",
device=ComponentDevice.from_str("cpu"),
token="test-token",
token=None,
huggingface_pipeline_kwargs=huggingface_pipeline_kwargs,
)
@ -142,10 +147,10 @@ class TestHuggingFaceLocalGenerator:
assert data == {
"type": "haystack.components.generators.hugging_face_local.HuggingFaceLocalGenerator",
"init_parameters": {
"token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
"huggingface_pipeline_kwargs": {
"model": "google/flan-t5-base",
"task": "text2text-generation",
"token": None,
"device": ComponentDevice.resolve_device(None).to_hf(),
},
"generation_kwargs": {},
@ -158,7 +163,7 @@ class TestHuggingFaceLocalGenerator:
model="gpt2",
task="text-generation",
device=ComponentDevice.from_str("cuda:0"),
token="test-token",
token=Secret.from_env_var("ENV_VAR", strict=False),
generation_kwargs={"max_new_tokens": 100},
stop_words=["coca", "cola"],
huggingface_pipeline_kwargs={
@ -175,6 +180,7 @@ class TestHuggingFaceLocalGenerator:
assert data == {
"type": "haystack.components.generators.hugging_face_local.HuggingFaceLocalGenerator",
"init_parameters": {
"token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
"huggingface_pipeline_kwargs": {
"model": "gpt2",
"task": "text-generation",
@ -197,7 +203,7 @@ class TestHuggingFaceLocalGenerator:
model="gpt2",
task="text-generation",
device=ComponentDevice.from_str("cuda:0"),
token="test-token",
token=None,
generation_kwargs={"max_new_tokens": 100},
stop_words=["coca", "cola"],
huggingface_pipeline_kwargs={
@ -216,6 +222,7 @@ class TestHuggingFaceLocalGenerator:
assert data == {
"type": "haystack.components.generators.hugging_face_local.HuggingFaceLocalGenerator",
"init_parameters": {
"token": None,
"huggingface_pipeline_kwargs": {
"model": "gpt2",
"task": "text-generation",
@ -239,6 +246,7 @@ class TestHuggingFaceLocalGenerator:
data = {
"type": "haystack.components.generators.hugging_face_local.HuggingFaceLocalGenerator",
"init_parameters": {
"token": None,
"huggingface_pipeline_kwargs": {
"model": "gpt2",
"task": "text-generation",
@ -276,9 +284,7 @@ class TestHuggingFaceLocalGenerator:
@patch("haystack.components.generators.hugging_face_local.pipeline")
def test_warm_up(self, pipeline_mock):
generator = HuggingFaceLocalGenerator(
model="google/flan-t5-base", task="text2text-generation", token="test-token"
)
generator = HuggingFaceLocalGenerator(model="google/flan-t5-base", task="text2text-generation", token=None)
pipeline_mock.assert_not_called()
generator.warm_up()
@ -286,14 +292,14 @@ class TestHuggingFaceLocalGenerator:
pipeline_mock.assert_called_once_with(
model="google/flan-t5-base",
task="text2text-generation",
token="test-token",
token=None,
device=ComponentDevice.resolve_device(None).to_hf(),
)
@patch("haystack.components.generators.hugging_face_local.pipeline")
def test_warm_up_doesn_reload(self, pipeline_mock):
generator = HuggingFaceLocalGenerator(
model="google/flan-t5-base", task="text2text-generation", token="test-token"
model="google/flan-t5-base", task="text2text-generation", token=Secret.from_token("fake-api-token")
)
pipeline_mock.assert_not_called()
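The huggingface_pipeline_kwargs assertions earlier in this file hinge on the component storing the resolved value rather than the Secret object itself; a minimal sketch of that resolution step:

from haystack.utils import Secret

# A token secret resolves to the wrapped string, which is what lands under
# huggingface_pipeline_kwargs["token"] in the assertions above.
assert Secret.from_token("fake-api-token").resolve_value() == "fake-api-token"
# Passing token=None keeps authentication disabled: the None is stored,
# serialized and forwarded to the transformers pipeline unchanged.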

View File

@ -1,4 +1,5 @@
from unittest.mock import patch, MagicMock, Mock
from haystack.utils.auth import Secret
import pytest
from huggingface_hub.inference._text_generation import TextGenerationStreamResponse, Token, StreamDetails, FinishReason
@ -59,7 +60,10 @@ class TestHuggingFaceTGIGenerator:
def test_to_dict(self, mock_check_valid_model):
# Initialize the HuggingFaceTGIGenerator object with valid parameters
generator = HuggingFaceTGIGenerator(
token="token", generation_kwargs={"n": 5}, stop_words=["stop", "words"], streaming_callback=lambda x: x
token=Secret.from_env_var("ENV_VAR", strict=False),
generation_kwargs={"n": 5},
stop_words=["stop", "words"],
streaming_callback=lambda x: x,
)
# Call the to_dict method
@ -68,7 +72,7 @@ class TestHuggingFaceTGIGenerator:
# Assert that the init_params dictionary contains the expected keys and values
assert init_params["model"] == "mistralai/Mistral-7B-v0.1"
assert not init_params["token"]
assert init_params["token"] == {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}
assert init_params["generation_kwargs"] == {"n": 5, "stop_sequences": ["stop", "words"]}
def test_from_dict(self, mock_check_valid_model):

View File

@ -1,5 +1,6 @@
import os
from typing import List
from haystack.utils.auth import Secret
import pytest
from openai import OpenAIError
@ -10,8 +11,9 @@ from haystack.dataclasses import StreamingChunk, ChatMessage
class TestOpenAIGenerator:
def test_init_default(self):
component = OpenAIGenerator(api_key="test-api-key")
def test_init_default(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
component = OpenAIGenerator()
assert component.client.api_key == "test-api-key"
assert component.model == "gpt-3.5-turbo"
assert component.streaming_callback is None
@ -19,12 +21,12 @@ class TestOpenAIGenerator:
def test_init_fail_wo_api_key(self, monkeypatch):
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
with pytest.raises(OpenAIError):
with pytest.raises(ValueError, match="None of the .* environment variables are set"):
OpenAIGenerator()
def test_init_with_parameters(self):
component = OpenAIGenerator(
api_key="test-api-key",
api_key=Secret.from_token("test-api-key"),
model="gpt-4",
streaming_callback=print_streaming_chunk,
api_base_url="test-base-url",
@ -35,12 +37,14 @@ class TestOpenAIGenerator:
assert component.streaming_callback is print_streaming_chunk
assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
def test_to_dict_default(self):
component = OpenAIGenerator(api_key="test-api-key")
def test_to_dict_default(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
component = OpenAIGenerator()
data = component.to_dict()
assert data == {
"type": "haystack.components.generators.openai.OpenAIGenerator",
"init_parameters": {
"api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
"model": "gpt-3.5-turbo",
"streaming_callback": None,
"system_prompt": None,
@ -49,9 +53,10 @@ class TestOpenAIGenerator:
},
}
def test_to_dict_with_parameters(self):
def test_to_dict_with_parameters(self, monkeypatch):
monkeypatch.setenv("ENV_VAR", "test-api-key")
component = OpenAIGenerator(
api_key="test-api-key",
api_key=Secret.from_env_var("ENV_VAR"),
model="gpt-4",
streaming_callback=print_streaming_chunk,
api_base_url="test-base-url",
@ -61,6 +66,7 @@ class TestOpenAIGenerator:
assert data == {
"type": "haystack.components.generators.openai.OpenAIGenerator",
"init_parameters": {
"api_key": {"env_vars": ["ENV_VAR"], "strict": True, "type": "env_var"},
"model": "gpt-4",
"system_prompt": None,
"api_base_url": "test-base-url",
@ -69,9 +75,9 @@ class TestOpenAIGenerator:
},
}
def test_to_dict_with_lambda_streaming_callback(self):
def test_to_dict_with_lambda_streaming_callback(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
component = OpenAIGenerator(
api_key="test-api-key",
model="gpt-4",
streaming_callback=lambda x: x,
api_base_url="test-base-url",
@ -81,6 +87,7 @@ class TestOpenAIGenerator:
assert data == {
"type": "haystack.components.generators.openai.OpenAIGenerator",
"init_parameters": {
"api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
"model": "gpt-4",
"system_prompt": None,
"api_base_url": "test-base-url",
@ -94,6 +101,7 @@ class TestOpenAIGenerator:
data = {
"type": "haystack.components.generators.openai.OpenAIGenerator",
"init_parameters": {
"api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
"model": "gpt-4",
"system_prompt": None,
"api_base_url": "test-base-url",
@ -106,23 +114,25 @@ class TestOpenAIGenerator:
assert component.streaming_callback is print_streaming_chunk
assert component.api_base_url == "test-base-url"
assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
assert component.api_key == Secret.from_env_var("OPENAI_API_KEY")
def test_from_dict_fail_wo_env_var(self, monkeypatch):
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
data = {
"type": "haystack.components.generators.openai.OpenAIGenerator",
"init_parameters": {
"api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
"model": "gpt-4",
"api_base_url": "test-base-url",
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
},
}
with pytest.raises(OpenAIError):
with pytest.raises(ValueError, match="None of the .* environment variables are set"):
OpenAIGenerator.from_dict(data)
def test_run(self, mock_chat_completion):
component = OpenAIGenerator(api_key="test-api-key")
component = OpenAIGenerator(api_key=Secret.from_token("test-api-key"))
response = component.run("What's Natural Language Processing?")
# check that the component returns the correct ChatMessage response
@ -139,7 +149,7 @@ class TestOpenAIGenerator:
nonlocal streaming_callback_called
streaming_callback_called = True
component = OpenAIGenerator(streaming_callback=streaming_callback)
component = OpenAIGenerator(api_key=Secret.from_token("test-api-key"), streaming_callback=streaming_callback)
response = component.run("Come on, stream!")
# check we called the streaming callback
@ -153,7 +163,9 @@ class TestOpenAIGenerator:
assert "Hello" in response["replies"][0] # see mock_chat_completion_chunk
def test_run_with_params(self, mock_chat_completion):
component = OpenAIGenerator(api_key="test-api-key", generation_kwargs={"max_tokens": 10, "temperature": 0.5})
component = OpenAIGenerator(
api_key=Secret.from_token("test-api-key"), generation_kwargs={"max_tokens": 10, "temperature": 0.5}
)
response = component.run("What's Natural Language Processing?")
# check that the component calls the OpenAI API with the correct parameters
@ -169,7 +181,7 @@ class TestOpenAIGenerator:
assert [isinstance(reply, str) for reply in response["replies"]]
def test_check_abnormal_completions(self, caplog):
component = OpenAIGenerator(api_key="test-api-key")
component = OpenAIGenerator(api_key=Secret.from_token("test-api-key"))
# underlying implementation uses ChatMessage objects so we have to use them here
messages: List[ChatMessage] = []
@ -202,7 +214,7 @@ class TestOpenAIGenerator:
)
@pytest.mark.integration
def test_live_run(self):
component = OpenAIGenerator(api_key=os.environ.get("OPENAI_API_KEY"))
component = OpenAIGenerator()
results = component.run("What's the capital of France?")
assert len(results["replies"]) == 1
assert len(results["meta"]) == 1
@ -224,7 +236,7 @@ class TestOpenAIGenerator:
)
@pytest.mark.integration
def test_live_run_wrong_model(self):
component = OpenAIGenerator(model="something-obviously-wrong", api_key=os.environ.get("OPENAI_API_KEY"))
component = OpenAIGenerator(model="something-obviously-wrong")
with pytest.raises(OpenAIError):
component.run("Whatever")
@ -244,7 +256,7 @@ class TestOpenAIGenerator:
self.responses += chunk.content if chunk.content else ""
callback = Callback()
component = OpenAIGenerator(os.environ.get("OPENAI_API_KEY"), streaming_callback=callback)
component = OpenAIGenerator(streaming_callback=callback)
results = component.run("What's the capital of France?")
assert len(results["replies"]) == 1

View File

@ -1,4 +1,5 @@
from unittest.mock import MagicMock, patch
from haystack.utils.auth import Secret
import pytest
import logging
@ -19,7 +20,7 @@ class TestSimilarityRanker:
"init_parameters": {
"device": None,
"top_k": 10,
"token": None,
"token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
"query_prefix": "",
"document_prefix": "",
"model": "cross-encoder/ms-marco-MiniLM-L-6-v2",
@ -36,7 +37,7 @@ class TestSimilarityRanker:
component = TransformersSimilarityRanker(
model="my_model",
device=ComponentDevice.from_str("cuda:0"),
token="my_token",
token=Secret.from_env_var("ENV_VAR", strict=False),
top_k=5,
query_prefix="query_instruction: ",
document_prefix="document_instruction: ",
@ -51,7 +52,7 @@ class TestSimilarityRanker:
"init_parameters": {
"device": None,
"model": "my_model",
"token": None, # we don't serialize valid tokens,
"token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
"top_k": 5,
"query_prefix": "query_instruction: ",
"document_prefix": "document_instruction: ",
@ -84,7 +85,7 @@ class TestSimilarityRanker:
"top_k": 10,
"query_prefix": "",
"document_prefix": "",
"token": None,
"token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
"model": "cross-encoder/ms-marco-MiniLM-L-6-v2",
"meta_fields_to_embed": [],
"embedding_separator": "\n",
@ -110,7 +111,7 @@ class TestSimilarityRanker:
],
)
def test_to_dict_device_map(self, device_map, expected):
component = TransformersSimilarityRanker(model_kwargs={"device_map": device_map})
component = TransformersSimilarityRanker(model_kwargs={"device_map": device_map}, token=None)
data = component.to_dict()
assert data == {

View File

@ -1,5 +1,6 @@
import os
from unittest.mock import Mock, patch
from haystack.utils.auth import Secret
import pytest
from requests import Timeout, RequestException, HTTPError
@ -369,17 +370,19 @@ def mock_searchapi_search_result():
class TestSearchApiSearchAPI:
def test_init_fail_wo_api_key(self, monkeypatch):
monkeypatch.delenv("SEARCHAPI_API_KEY", raising=False)
with pytest.raises(ValueError, match="SearchApiWebSearch expects an API key"):
with pytest.raises(ValueError, match="None of the .* environment variables are set"):
SearchApiWebSearch()
def test_to_dict(self):
def test_to_dict(self, monkeypatch):
monkeypatch.setenv("SEARCHAPI_API_KEY", "test-api-key")
component = SearchApiWebSearch(
api_key="api_key", top_k=10, allowed_domains=["testdomain.com"], search_params={"param": "test params"}
top_k=10, allowed_domains=["testdomain.com"], search_params={"param": "test params"}
)
data = component.to_dict()
assert data == {
"type": "haystack.components.websearch.searchapi.SearchApiWebSearch",
"init_parameters": {
"api_key": {"env_vars": ["SEARCHAPI_API_KEY"], "strict": True, "type": "env_var"},
"top_k": 10,
"allowed_domains": ["testdomain.com"],
"search_params": {"param": "test params"},
@ -388,7 +391,7 @@ class TestSearchApiSearchAPI:
@pytest.mark.parametrize("top_k", [1, 5, 7])
def test_web_search_top_k(self, mock_searchapi_search_result, top_k: int):
ws = SearchApiWebSearch(api_key="api_key", top_k=top_k)
ws = SearchApiWebSearch(api_key=Secret.from_token("test-api-key"), top_k=top_k)
results = ws.run(query="Who is CEO of Microsoft?")
documents = results["documents"]
links = results["links"]
@ -400,7 +403,7 @@ class TestSearchApiSearchAPI:
@patch("requests.get")
def test_timeout_error(self, mock_get):
mock_get.side_effect = Timeout
ws = SearchApiWebSearch(api_key="api_key")
ws = SearchApiWebSearch(api_key=Secret.from_token("test-api-key"))
with pytest.raises(TimeoutError):
ws.run(query="Who is CEO of Microsoft?")
@ -408,7 +411,7 @@ class TestSearchApiSearchAPI:
@patch("requests.get")
def test_request_exception(self, mock_get):
mock_get.side_effect = RequestException
ws = SearchApiWebSearch(api_key="api_key")
ws = SearchApiWebSearch(api_key=Secret.from_token("test-api-key"))
with pytest.raises(SearchApiError):
ws.run(query="Who is CEO of Microsoft?")
@ -418,7 +421,7 @@ class TestSearchApiSearchAPI:
mock_response = mock_get.return_value
mock_response.status_code = 404
mock_response.raise_for_status.side_effect = HTTPError
ws = SearchApiWebSearch(api_key="api_key")
ws = SearchApiWebSearch(api_key=Secret.from_token("test-api-key"))
with pytest.raises(SearchApiError):
ws.run(query="Who is CEO of Microsoft?")
@ -429,7 +432,7 @@ class TestSearchApiSearchAPI:
)
@pytest.mark.integration
def test_web_search(self):
ws = SearchApiWebSearch(api_key=os.environ.get("SEARCHAPI_API_KEY", None), top_k=10)
ws = SearchApiWebSearch(top_k=10)
results = ws.run(query="Who is CEO of Microsoft?")
documents = results["documents"]
links = results["links"]

View File

@ -1,5 +1,6 @@
import os
from unittest.mock import Mock, patch
from haystack.utils.auth import Secret
import pytest
from requests import Timeout, RequestException, HTTPError
@ -110,22 +111,26 @@ def mock_serper_dev_search_result():
class TestSerperDevSearchAPI:
def test_init_fail_wo_api_key(self, monkeypatch):
monkeypatch.delenv("SERPERDEV_API_KEY", raising=False)
with pytest.raises(ValueError, match="SerperDevWebSearch expects an API key"):
with pytest.raises(ValueError, match="None of the .* environment variables are set"):
SerperDevWebSearch()
def test_to_dict(self):
component = SerperDevWebSearch(
api_key="test_key", top_k=10, allowed_domains=["test.com"], search_params={"param": "test"}
)
def test_to_dict(self, monkeypatch):
monkeypatch.setenv("SERPERDEV_API_KEY", "test-api-key")
component = SerperDevWebSearch(top_k=10, allowed_domains=["test.com"], search_params={"param": "test"})
data = component.to_dict()
assert data == {
"type": "haystack.components.websearch.serper_dev.SerperDevWebSearch",
"init_parameters": {"top_k": 10, "allowed_domains": ["test.com"], "search_params": {"param": "test"}},
"init_parameters": {
"api_key": {"env_vars": ["SERPERDEV_API_KEY"], "strict": True, "type": "env_var"},
"top_k": 10,
"allowed_domains": ["test.com"],
"search_params": {"param": "test"},
},
}
@pytest.mark.parametrize("top_k", [1, 5, 7])
def test_web_search_top_k(self, mock_serper_dev_search_result, top_k: int):
ws = SerperDevWebSearch(api_key="some_invalid_key", top_k=top_k)
ws = SerperDevWebSearch(api_key=Secret.from_token("test-api-key"), top_k=top_k)
results = ws.run(query="Who is the boyfriend of Olivia Wilde?")
documents = results["documents"]
links = results["links"]
@ -137,7 +142,7 @@ class TestSerperDevSearchAPI:
@patch("requests.post")
def test_timeout_error(self, mock_post):
mock_post.side_effect = Timeout
ws = SerperDevWebSearch(api_key="some_invalid_key")
ws = SerperDevWebSearch(api_key=Secret.from_token("test-api-key"))
with pytest.raises(TimeoutError):
ws.run(query="Who is the boyfriend of Olivia Wilde?")
@ -145,7 +150,7 @@ class TestSerperDevSearchAPI:
@patch("requests.post")
def test_request_exception(self, mock_post):
mock_post.side_effect = RequestException
ws = SerperDevWebSearch(api_key="some_invalid_key")
ws = SerperDevWebSearch(api_key=Secret.from_token("test-api-key"))
with pytest.raises(SerperDevError):
ws.run(query="Who is the boyfriend of Olivia Wilde?")
@ -155,7 +160,7 @@ class TestSerperDevSearchAPI:
mock_response = mock_post.return_value
mock_response.status_code = 404
mock_response.raise_for_status.side_effect = HTTPError
ws = SerperDevWebSearch(api_key="some_invalid_key")
ws = SerperDevWebSearch(api_key=Secret.from_token("test-api-key"))
with pytest.raises(SerperDevError):
ws.run(query="Who is the boyfriend of Olivia Wilde?")
@ -166,7 +171,7 @@ class TestSerperDevSearchAPI:
)
@pytest.mark.integration
def test_web_search(self):
ws = SerperDevWebSearch(api_key=os.environ.get("SERPERDEV_API_KEY", None), top_k=10)
ws = SerperDevWebSearch(top_k=10)
results = ws.run(query="Who is the boyfriend of Olivia Wilde?")
documents = results["documents"]
links = results["documents"]