feat!: Use Secret for passing authentication secrets to components (#6887)

* feat!: Use `Secret` for passing authentication secrets to components

* Add comment to clarify type ignore
Madeesh Kannan 2024-02-05 13:17:01 +01:00 committed by GitHub
parent 393a7993c3
commit 27d1af3068
52 changed files with 707 additions and 421 deletions
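
Every file below applies the same mechanical change: plain-string credentials become `Secret` values that are resolved only at the point of use and serialized as descriptors instead of raw keys. A minimal sketch of the pattern, using only the calls that appear in the hunks below:

```python
from haystack.utils import Secret, deserialize_secrets_inplace

# Default for most components: read from an environment variable.
# strict=True (the default) raises on resolution if the variable is unset;
# strict=False resolves to None so another credential can be used instead.
api_key = Secret.from_env_var("OPENAI_API_KEY")
token = Secret.from_env_var("HF_API_TOKEN", strict=False)

# Wrapping a literal value, e.g. in tests or examples:
api_key = Secret.from_token("<your-api-key>")

# Components defer resolution until they build their client:
raw_key = api_key.resolve_value()

# Serialization round trip used by the new to_dict()/from_dict() pairs:
data = {"init_parameters": {"api_key": Secret.from_env_var("OPENAI_API_KEY").to_dict()}}
deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"])
```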

View File

@@ -7,6 +7,7 @@ from openai import OpenAI
 from haystack import Document, component, default_from_dict, default_to_dict
 from haystack.dataclasses import ByteStream
+from haystack.utils import Secret, deserialize_secrets_inplace

 logger = logging.getLogger(__name__)

@@ -24,7 +25,7 @@ class RemoteWhisperTranscriber:
     def __init__(
         self,
-        api_key: Optional[str] = None,
+        api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
         model: str = "whisper-1",
         api_base_url: Optional[str] = None,
         organization: Optional[str] = None,

@@ -61,6 +62,7 @@ class RemoteWhisperTranscriber:
         self.organization = organization
         self.model = model
         self.api_base_url = api_base_url
+        self.api_key = api_key

         # Only response_format = "json" is supported
         whisper_params = kwargs

@@ -71,7 +73,7 @@ class RemoteWhisperTranscriber:
             )
         whisper_params["response_format"] = "json"
         self.whisper_params = whisper_params
-        self.client = OpenAI(api_key=api_key, organization=organization, base_url=api_base_url)
+        self.client = OpenAI(api_key=api_key.resolve_value(), organization=organization, base_url=api_base_url)

     def to_dict(self) -> Dict[str, Any]:
         """

@@ -81,6 +83,7 @@ class RemoteWhisperTranscriber:
         """
         return default_to_dict(
             self,
+            api_key=self.api_key.to_dict(),
             model=self.model,
             organization=self.organization,
             api_base_url=self.api_base_url,

@@ -92,6 +95,7 @@ class RemoteWhisperTranscriber:
         """
         Deserialize this component from a dictionary.
         """
+        deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"])
         return default_from_dict(cls, data)

     @component.output_types(documents=List[Document])
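
For instance, the transcriber can now round-trip through serialization without ever writing out the key itself (a sketch, assuming `OPENAI_API_KEY` is set and the usual import path):

```python
from haystack.components.audio import RemoteWhisperTranscriber

transcriber = RemoteWhisperTranscriber()           # picks up OPENAI_API_KEY by default
data = transcriber.to_dict()                       # stores an env-var descriptor, not the key
restored = RemoteWhisperTranscriber.from_dict(data)
```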

View File

@@ -29,10 +29,11 @@ class DynamicChatPromptBuilder:
     from haystack.components.generators.chat import OpenAIChatGenerator
     from haystack.dataclasses import ChatMessage
     from haystack import Pipeline
+    from haystack.utils import Secret

     # no parameter init, we don't use any runtime template variables
     prompt_builder = DynamicChatPromptBuilder()
-    llm = OpenAIChatGenerator(api_key="<your-api-key>", model="gpt-3.5-turbo")
+    llm = OpenAIChatGenerator(api_key=Secret.from_token("<your-api-key>"), model="gpt-3.5-turbo")

     pipe = Pipeline()
     pipe.add_component("prompt_builder", prompt_builder)

View File

@@ -21,9 +21,10 @@ class DynamicPromptBuilder:
     from haystack.components.builders import DynamicPromptBuilder
     from haystack.components.generators import OpenAIGenerator
     from haystack import Pipeline, component, Document
+    from haystack.utils import Secret

     prompt_builder = DynamicPromptBuilder(runtime_variables=["documents"])
-    llm = OpenAIGenerator(api_key="<your-api-key>", model="gpt-3.5-turbo")
+    llm = OpenAIGenerator(api_key=Secret.from_token("<your-api-key>"), model="gpt-3.5-turbo")

     @component

View File

@@ -1,12 +1,12 @@
 from pathlib import Path
 from typing import List, Union, Dict, Any, Optional
 import logging
-import os

 from haystack.lazy_imports import LazyImport
-from haystack import component, Document, default_to_dict
+from haystack import component, Document, default_to_dict, default_from_dict
 from haystack.dataclasses import ByteStream
 from haystack.components.converters.utils import get_bytestream_from_source, normalize_metadata
+from haystack.utils import Secret, deserialize_secrets_inplace

 logger = logging.getLogger(__name__)

@@ -29,8 +29,9 @@ class AzureOCRDocumentConverter:
     Usage example:
     ```python
     from haystack.components.converters.azure import AzureOCRDocumentConverter
+    from haystack.utils import Secret

-    converter = AzureOCRDocumentConverter()
+    converter = AzureOCRDocumentConverter(endpoint="<url>", api_key=Secret.from_token("<your-api-key>"))
     results = converter.run(sources=["image-based-document.pdf"], meta={"date_added": datetime.now().isoformat()})
     documents = results["documents"]
     print(documents[0].content)

@@ -38,33 +39,23 @@ class AzureOCRDocumentConverter:
     ```
     """

-    def __init__(self, endpoint: str, api_key: Optional[str] = None, model_id: str = "prebuilt-read"):
+    def __init__(
+        self, endpoint: str, api_key: Secret = Secret.from_env_var("AZURE_AI_API_KEY"), model_id: str = "prebuilt-read"
+    ):
         """
         Create an AzureOCRDocumentConverter component.

         :param endpoint: The endpoint of your Azure resource.
-        :param api_key: The key of your Azure resource. It can be
-            explicitly provided or automatically read from the
-            environment variable AZURE_AI_API_KEY (recommended).
+        :param api_key: The key of your Azure resource.
         :param model_id: The model ID of the model you want to use. Please refer to [Azure documentation](https://learn.microsoft.com/en-us/azure/ai-services/document-intelligence/choose-model-feature)
             for a list of available models. Default: `"prebuilt-read"`.
         """
         azure_import.check()

-        api_key = api_key or os.environ.get("AZURE_AI_API_KEY")
-        # we check whether api_key is None or an empty string
-        if not api_key:
-            msg = (
-                "AzureOCRDocumentConverter expects an API key. "
-                "Set the AZURE_AI_API_KEY environment variable (recommended) or pass it explicitly."
-            )
-            raise ValueError(msg)
-        self.document_analysis_client = DocumentAnalysisClient(
-            endpoint=endpoint, credential=AzureKeyCredential(api_key)
-        )
+        self.document_analysis_client = DocumentAnalysisClient(endpoint=endpoint, credential=AzureKeyCredential(api_key.resolve_value()))  # type: ignore
         self.endpoint = endpoint
         self.model_id = model_id
+        self.api_key = api_key

     @component.output_types(documents=List[Document], raw_azure_response=List[Dict])
     def run(self, sources: List[Union[str, Path, ByteStream]], meta: Optional[List[Dict[str, Any]]] = None):

@@ -116,7 +107,15 @@ class AzureOCRDocumentConverter:
         """
         Serialize this component to a dictionary.
         """
-        return default_to_dict(self, endpoint=self.endpoint, model_id=self.model_id)
+        return default_to_dict(self, api_key=self.api_key.to_dict(), endpoint=self.endpoint, model_id=self.model_id)
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "AzureOCRDocumentConverter":
+        """
+        Deserialize this component from a dictionary.
+        """
+        deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"])
+        return default_from_dict(cls, data)

     @staticmethod
     def _convert_azure_result_to_document(result: "AnalyzeResult", file_suffix: Optional[str] = None) -> Document:
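
Note that `AZURE_AI_API_KEY` is a strict secret here: a missing key now fails on `resolve_value()` inside the constructor rather than via the hand-rolled check that this hunk removes. A sketch of both construction styles:

```python
from haystack.components.converters.azure import AzureOCRDocumentConverter
from haystack.utils import Secret

# Default: resolves AZURE_AI_API_KEY when the client is built, raising if unset.
converter = AzureOCRDocumentConverter(endpoint="<url>")

# Explicit token, e.g. for a quick local test:
converter = AzureOCRDocumentConverter(endpoint="<url>", api_key=Secret.from_token("<your-api-key>"))
```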

View File

@@ -1,10 +1,11 @@
 import os
 from typing import List, Optional, Dict, Any, Tuple

-from openai.lib.azure import AzureADTokenProvider, AzureOpenAI
+from openai.lib.azure import AzureOpenAI
 from tqdm import tqdm

-from haystack import component, Document, default_to_dict
+from haystack import component, Document, default_to_dict, default_from_dict
+from haystack.utils import Secret, deserialize_secrets_inplace


 @component

@@ -34,9 +35,8 @@ class AzureOpenAIDocumentEmbedder:
         azure_endpoint: Optional[str] = None,
         api_version: Optional[str] = "2023-05-15",
         azure_deployment: str = "text-embedding-ada-002",
-        api_key: Optional[str] = None,
-        azure_ad_token: Optional[str] = None,
-        azure_ad_token_provider: Optional[AzureADTokenProvider] = None,
+        api_key: Optional[Secret] = Secret.from_env_var("AZURE_OPENAI_API_KEY", strict=False),
+        azure_ad_token: Optional[Secret] = Secret.from_env_var("AZURE_OPENAI_AD_TOKEN", strict=False),
         organization: Optional[str] = None,
         prefix: str = "",
         suffix: str = "",

@@ -53,8 +53,6 @@ class AzureOpenAIDocumentEmbedder:
         :param azure_deployment: The deployment of the model, usually the model name.
         :param api_key: The API key to use for authentication.
         :param azure_ad_token: Azure Active Directory token, see https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id
-        :param azure_ad_token_provider: A function that returns an Azure Active Directory token, will be invoked
-            on every request.
         :param organization: The Organization ID, defaults to `None`. See
             [production best practices](https://platform.openai.com/docs/guides/production-best-practices/setting-up-your-organization).
         :param prefix: A string to add to the beginning of each text.

@@ -70,6 +68,11 @@ class AzureOpenAIDocumentEmbedder:
         if not azure_endpoint:
             raise ValueError("Please provide an Azure endpoint or set the environment variable AZURE_OPENAI_ENDPOINT.")

+        if api_key is None and azure_ad_token is None:
+            raise ValueError("Please provide an API key or an Azure Active Directory token.")
+
+        self.api_key = api_key
+        self.azure_ad_token = azure_ad_token
         self.api_version = api_version
         self.azure_endpoint = azure_endpoint
         self.azure_deployment = azure_deployment

@@ -85,9 +88,8 @@ class AzureOpenAIDocumentEmbedder:
             api_version=api_version,
             azure_endpoint=azure_endpoint,
             azure_deployment=azure_deployment,
-            api_key=api_key,
-            azure_ad_token=azure_ad_token,
-            azure_ad_token_provider=azure_ad_token_provider,
+            api_key=api_key.resolve_value() if api_key is not None else None,
+            azure_ad_token=azure_ad_token.resolve_value() if azure_ad_token is not None else None,
             organization=organization,
         )

@@ -98,10 +100,6 @@ class AzureOpenAIDocumentEmbedder:
         return {"model": self.azure_deployment}

     def to_dict(self) -> Dict[str, Any]:
-        """
-        This method overrides the default serializer in order to avoid leaking the `api_key` value passed
-        to the constructor.
-        """
         return default_to_dict(
             self,
             azure_endpoint=self.azure_endpoint,

@@ -114,8 +112,15 @@ class AzureOpenAIDocumentEmbedder:
             progress_bar=self.progress_bar,
             meta_fields_to_embed=self.meta_fields_to_embed,
             embedding_separator=self.embedding_separator,
+            api_key=self.api_key.to_dict() if self.api_key is not None else None,
+            azure_ad_token=self.azure_ad_token.to_dict() if self.azure_ad_token is not None else None,
         )

+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "AzureOpenAIDocumentEmbedder":
+        deserialize_secrets_inplace(data["init_parameters"], keys=["api_key", "azure_ad_token"])
+        return default_from_dict(cls, data)
+
     def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
         """
         Prepare the texts to embed by concatenating the Document text with the metadata fields to embed.
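
The two Azure credentials are deliberately lenient (`strict=False`): an unset variable resolves to `None` instead of raising, letting the other credential authenticate the client, and the constructor raises immediately only when both are explicitly passed as `None`. A sketch:

```python
from haystack.components.embedders import AzureOpenAIDocumentEmbedder
from haystack.utils import Secret

# Authenticate via Azure AD instead of an API key:
embedder = AzureOpenAIDocumentEmbedder(
    azure_endpoint="<your-endpoint>",
    azure_ad_token=Secret.from_env_var("AZURE_OPENAI_AD_TOKEN"),
)
```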

View File

@@ -1,9 +1,10 @@
 import os
 from typing import List, Optional, Dict, Any

-from openai.lib.azure import AzureADTokenProvider, AzureOpenAI
+from openai.lib.azure import AzureOpenAI

-from haystack import component, default_to_dict, Document
+from haystack import component, Document, default_to_dict, default_from_dict
+from haystack.utils import Secret, deserialize_secrets_inplace


 @component

@@ -32,9 +33,8 @@ class AzureOpenAITextEmbedder:
         azure_endpoint: Optional[str] = None,
         api_version: Optional[str] = "2023-05-15",
         azure_deployment: str = "text-embedding-ada-002",
-        api_key: Optional[str] = None,
-        azure_ad_token: Optional[str] = None,
-        azure_ad_token_provider: Optional[AzureADTokenProvider] = None,
+        api_key: Optional[Secret] = Secret.from_env_var("AZURE_OPENAI_API_KEY", strict=False),
+        azure_ad_token: Optional[Secret] = Secret.from_env_var("AZURE_OPENAI_AD_TOKEN", strict=False),
         organization: Optional[str] = None,
         prefix: str = "",
         suffix: str = "",

@@ -47,8 +47,6 @@ class AzureOpenAITextEmbedder:
         :param azure_deployment: The deployment of the model, usually the model name.
         :param api_key: The API key to use for authentication.
         :param azure_ad_token: Azure Active Directory token, see https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id
-        :param azure_ad_token_provider: A function that returns an Azure Active Directory token, will be invoked
-            on every request.
         :param organization: The Organization ID, defaults to `None`. See
             [production best practices](https://platform.openai.com/docs/guides/production-best-practices/setting-up-your-organization).
         :param prefix: A string to add to the beginning of each text.

@@ -62,6 +60,11 @@ class AzureOpenAITextEmbedder:
         if not azure_endpoint:
             raise ValueError("Please provide an Azure endpoint or set the environment variable AZURE_OPENAI_ENDPOINT.")

+        if api_key is None and azure_ad_token is None:
+            raise ValueError("Please provide an API key or an Azure Active Directory token.")
+
+        self.api_key = api_key
+        self.azure_ad_token = azure_ad_token
         self.api_version = api_version
         self.azure_endpoint = azure_endpoint
         self.azure_deployment = azure_deployment

@@ -73,9 +76,8 @@ class AzureOpenAITextEmbedder:
             api_version=api_version,
             azure_endpoint=azure_endpoint,
             azure_deployment=azure_deployment,
-            api_key=api_key,
-            azure_ad_token=azure_ad_token,
-            azure_ad_token_provider=azure_ad_token_provider,
+            api_key=api_key.resolve_value() if api_key is not None else None,
+            azure_ad_token=azure_ad_token.resolve_value() if azure_ad_token is not None else None,
             organization=organization,
         )

@@ -98,8 +100,15 @@ class AzureOpenAITextEmbedder:
             api_version=self.api_version,
             prefix=self.prefix,
             suffix=self.suffix,
+            api_key=self.api_key.to_dict() if self.api_key is not None else None,
+            azure_ad_token=self.azure_ad_token.to_dict() if self.azure_ad_token is not None else None,
         )

+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "AzureOpenAITextEmbedder":
+        deserialize_secrets_inplace(data["init_parameters"], keys=["api_key", "azure_ad_token"])
+        return default_from_dict(cls, data)
+
     @component.output_types(embedding=List[float], meta=Dict[str, Any])
     def run(self, text: str):
         """Embed a string using AzureOpenAITextEmbedder."""

View File

@@ -1,6 +1,7 @@
-from typing import List, Optional, Union, Dict
+from typing import List, Optional, Dict

 from haystack.lazy_imports import LazyImport
+from haystack.utils.auth import Secret

 with LazyImport(message="Run 'pip install \"sentence-transformers>=2.2.0\"'") as sentence_transformers_import:
     from sentence_transformers import SentenceTransformer

@@ -14,14 +15,12 @@ class _SentenceTransformersEmbeddingBackendFactory:
     _instances: Dict[str, "_SentenceTransformersEmbeddingBackend"] = {}

     @staticmethod
-    def get_embedding_backend(model: str, device: Optional[str] = None, use_auth_token: Union[bool, str, None] = None):
-        embedding_backend_id = f"{model}{device}{use_auth_token}"
+    def get_embedding_backend(model: str, device: Optional[str] = None, auth_token: Optional[Secret] = None):
+        embedding_backend_id = f"{model}{device}{auth_token}"

         if embedding_backend_id in _SentenceTransformersEmbeddingBackendFactory._instances:
             return _SentenceTransformersEmbeddingBackendFactory._instances[embedding_backend_id]
-        embedding_backend = _SentenceTransformersEmbeddingBackend(
-            model=model, device=device, use_auth_token=use_auth_token
-        )
+        embedding_backend = _SentenceTransformersEmbeddingBackend(model=model, device=device, auth_token=auth_token)
         _SentenceTransformersEmbeddingBackendFactory._instances[embedding_backend_id] = embedding_backend
         return embedding_backend

@@ -31,9 +30,11 @@ class _SentenceTransformersEmbeddingBackend:
     Class to manage Sentence Transformers embeddings.
     """

-    def __init__(self, model: str, device: Optional[str] = None, use_auth_token: Union[bool, str, None] = None):
+    def __init__(self, model: str, device: Optional[str] = None, auth_token: Optional[Secret] = None):
         sentence_transformers_import.check()
-        self.model = SentenceTransformer(model_name_or_path=model, device=device, use_auth_token=use_auth_token)
+        self.model = SentenceTransformer(
+            model_name_or_path=model, device=device, use_auth_token=auth_token.resolve_value() if auth_token else None
+        )

     def embed(self, data: List[str], **kwargs) -> List[List[float]]:
         embeddings = self.model.encode(data, **kwargs).tolist()
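
The cache key now interpolates the `Secret` object rather than the raw token string; for env-var secrets the string form is deterministic, so equal secrets keep mapping to the same cached backend. A sketch of that caching behavior (my reading of the factory above, not shown in this hunk):

```python
from haystack.components.embedders.backends.sentence_transformers_backend import (
    _SentenceTransformersEmbeddingBackendFactory,
)
from haystack.utils.auth import Secret

token = Secret.from_env_var("HF_API_TOKEN", strict=False)
b1 = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(
    model="sentence-transformers/all-mpnet-base-v2", auth_token=token
)
b2 = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(
    model="sentence-transformers/all-mpnet-base-v2", auth_token=token
)
assert b1 is b2  # same backend id, same cached instance
```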

View File

@@ -1,26 +1,27 @@
 from typing import Optional

 from haystack.lazy_imports import LazyImport
+from haystack.utils.auth import Secret

 with LazyImport(message="Run 'pip install transformers'") as transformers_import:
     from huggingface_hub import HfApi
     from huggingface_hub.utils import RepositoryNotFoundError


-def check_valid_model(model_id: str, token: Optional[str]) -> None:
+def check_valid_model(model_id: str, token: Optional[Secret]) -> None:
     """
     Check if the provided model ID corresponds to a valid model on HuggingFace Hub.
     Also check if the model is an embedding model.

     :param model_id: A string representing the HuggingFace model ID.
-    :param token: An optional string representing the authentication token.
+    :param token: The optional authentication token.
     :raises ValueError: If the model is not found or is not an embedding model.
     """
     transformers_import.check()

     api = HfApi()
     try:
-        model_info = api.model_info(model_id, token=token)
+        model_info = api.model_info(model_id, token=token.resolve_value() if token else None)
     except RepositoryNotFoundError as e:
         raise ValueError(
             f"Model {model_id} not found on HuggingFace Hub. Please provide a valid HuggingFace model_id."

View File

@@ -1,14 +1,14 @@
 import logging
-import os
 from typing import Any, Dict, List, Optional
 from urllib.parse import urlparse

 from tqdm import tqdm

-from haystack import component, default_to_dict
 from haystack.components.embedders.hf_utils import check_valid_model
 from haystack.dataclasses import Document
 from haystack.lazy_imports import LazyImport
+from haystack.utils import Secret, deserialize_secrets_inplace
+from haystack import component, default_to_dict, default_from_dict

 with LazyImport(message="Run 'pip install transformers'") as transformers_import:
     from huggingface_hub import InferenceClient

@@ -29,11 +29,12 @@ class HuggingFaceTEIDocumentEmbedder:
     ```python
     from haystack.dataclasses import Document
     from haystack.components.embedders import HuggingFaceTEIDocumentEmbedder
+    from haystack.utils import Secret

     doc = Document(content="I love pizza!")

     document_embedder = HuggingFaceTEIDocumentEmbedder(
-        model="BAAI/bge-small-en-v1.5", token="<your-token>"
+        model="BAAI/bge-small-en-v1.5", token=Secret.from_token("<your-api-key>")
     )

     result = document_embedder.run([doc])

@@ -52,7 +53,7 @@ class HuggingFaceTEIDocumentEmbedder:
     doc = Document(content="I love pizza!")

     document_embedder = HuggingFaceTEIDocumentEmbedder(
-        model="BAAI/bge-small-en-v1.5", url="<your-tei-endpoint-url>", token="<your-token>"
+        model="BAAI/bge-small-en-v1.5", url="<your-tei-endpoint-url>", token=Secret.from_token("<your-api-key>")
     )

     result = document_embedder.run([doc])

@@ -83,7 +84,7 @@ class HuggingFaceTEIDocumentEmbedder:
         self,
         model: str = "BAAI/bge-small-en-v1.5",
         url: Optional[str] = None,
-        token: Optional[str] = None,
+        token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
         prefix: str = "",
         suffix: str = "",
         batch_size: int = 32,

@@ -98,8 +99,7 @@ class HuggingFaceTEIDocumentEmbedder:
         :param url: The URL of your self-deployed Text-Embeddings-Inference service or the URL of your paid HF Inference
             Endpoint.
         :param token: The HuggingFace Hub token. This is needed if you are using a paid HF Inference Endpoint or serving
-            a private or gated model. It can be explicitly provided or automatically read from the environment
-            variable HF_API_TOKEN (recommended).
+            a private or gated model.
         :param prefix: A string to add to the beginning of each text.
         :param suffix: A string to add to the end of each text.
         :param batch_size: Number of Documents to encode at once.

@@ -116,15 +116,12 @@ class HuggingFaceTEIDocumentEmbedder:
         if not is_valid_url:
             raise ValueError(f"Invalid TEI endpoint URL provided: {url}")

-        # The user does not need to provide a token if it is a local server or free public HF Inference Endpoint.
-        token = token or os.environ.get("HF_API_TOKEN")
-
         check_valid_model(model, token)

         self.model = model
         self.url = url
         self.token = token
-        self.client = InferenceClient(url or model, token=token)
+        self.client = InferenceClient(url or model, token=token.resolve_value() if token else None)
         self.prefix = prefix
         self.suffix = suffix
         self.batch_size = batch_size

@@ -133,10 +130,6 @@ class HuggingFaceTEIDocumentEmbedder:
         self.embedding_separator = embedding_separator

     def to_dict(self) -> Dict[str, Any]:
-        """
-        This method overrides the default serializer in order to avoid leaking the `token` value passed
-        to the constructor.
-        """
         return default_to_dict(
             self,
             model=self.model,

@@ -147,8 +140,14 @@ class HuggingFaceTEIDocumentEmbedder:
             progress_bar=self.progress_bar,
             meta_fields_to_embed=self.meta_fields_to_embed,
             embedding_separator=self.embedding_separator,
+            token=self.token.to_dict() if self.token else None,
         )

+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceTEIDocumentEmbedder":
+        deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
+        return default_from_dict(cls, data)
+
     def _get_telemetry_data(self) -> Dict[str, Any]:
         """
         Data that is sent to Posthog for usage analytics.
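
One consequence worth calling out: if the `Secret` semantics here match the rest of this PR, a token-based secret refuses to serialize, so `to_dict()` can no longer leak a pasted key (a sketch; the ValueError behavior is my reading of the `Secret` contract, not shown in this hunk):

```python
from haystack.components.embedders import HuggingFaceTEIDocumentEmbedder
from haystack.utils import Secret

embedder = HuggingFaceTEIDocumentEmbedder(model="BAAI/bge-small-en-v1.5")
data = embedder.to_dict()  # env-var secret serializes as a descriptor

leaky = HuggingFaceTEIDocumentEmbedder(
    model="BAAI/bge-small-en-v1.5", token=Secret.from_token("<your-api-key>")
)
leaky.to_dict()  # expected to raise: token-based secrets cannot be serialized
```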

View File

@@ -1,11 +1,11 @@
 import logging
-import os
 from typing import Any, Dict, List, Optional
 from urllib.parse import urlparse

-from haystack import component, default_to_dict
+from haystack import component, default_to_dict, default_from_dict
 from haystack.components.embedders.hf_utils import check_valid_model
 from haystack.lazy_imports import LazyImport
+from haystack.utils import Secret, deserialize_secrets_inplace

 with LazyImport(message="Run 'pip install transformers'") as transformers_import:
     from huggingface_hub import InferenceClient

@@ -23,11 +23,12 @@ class HuggingFaceTEITextEmbedder:
     Inference API tier:

     ```python
     from haystack.components.embedders import HuggingFaceTEITextEmbedder
+    from haystack.utils import Secret

     text_to_embed = "I love pizza!"

     text_embedder = HuggingFaceTEITextEmbedder(
-        model="BAAI/bge-small-en-v1.5", token="<your-token>"
+        model="BAAI/bge-small-en-v1.5", token=Secret.from_token("<your-api-key>")
     )

     print(text_embedder.run(text_to_embed))

@@ -44,7 +45,7 @@ class HuggingFaceTEITextEmbedder:
     text_to_embed = "I love pizza!"

     text_embedder = HuggingFaceTEITextEmbedder(
-        model="BAAI/bge-small-en-v1.5", url="<your-tei-endpoint-url>", token="<your-token>"
+        model="BAAI/bge-small-en-v1.5", url="<your-tei-endpoint-url>", token=Secret.from_token("<your-api-key>")
     )

     print(text_embedder.run(text_to_embed))

@@ -74,7 +75,7 @@ class HuggingFaceTEITextEmbedder:
         self,
         model: str = "BAAI/bge-small-en-v1.5",
         url: Optional[str] = None,
-        token: Optional[str] = None,
+        token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
         prefix: str = "",
         suffix: str = "",
     ):

@@ -85,8 +86,7 @@ class HuggingFaceTEITextEmbedder:
         :param url: The URL of your self-deployed Text-Embeddings-Inference service or the URL of your paid HF Inference
             Endpoint.
         :param token: The HuggingFace Hub token. This is needed if you are using a paid HF Inference Endpoint or serving
-            a private or gated model. It can be explicitly provided or automatically read from the environment
-            variable HF_API_TOKEN (recommended).
+            a private or gated model.
         :param prefix: A string to add to the beginning of each text.
         :param suffix: A string to add to the end of each text.
         """

@@ -98,24 +98,29 @@ class HuggingFaceTEITextEmbedder:
         if not is_valid_url:
             raise ValueError(f"Invalid TEI endpoint URL provided: {url}")

-        # The user does not need to provide a token if it is a local server or free public HF Inference Endpoint.
-        token = token or os.environ.get("HF_API_TOKEN")
-
         check_valid_model(model, token)

         self.model = model
         self.url = url
         self.token = token
-        self.client = InferenceClient(url or model, token=token)
+        self.client = InferenceClient(url or model, token=token.resolve_value() if token else None)
         self.prefix = prefix
         self.suffix = suffix

     def to_dict(self) -> Dict[str, Any]:
-        """
-        This method overrides the default serializer in order to avoid leaking the `token` value passed
-        to the constructor.
-        """
-        return default_to_dict(self, model=self.model, url=self.url, prefix=self.prefix, suffix=self.suffix)
+        return default_to_dict(
+            self,
+            model=self.model,
+            url=self.url,
+            prefix=self.prefix,
+            suffix=self.suffix,
+            token=self.token.to_dict() if self.token else None,
+        )
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceTEITextEmbedder":
+        deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
+        return default_from_dict(cls, data)

     def _get_telemetry_data(self) -> Dict[str, Any]:
         """

View File

@@ -3,7 +3,8 @@ from typing import List, Optional, Dict, Any, Tuple
 from openai import OpenAI
 from tqdm import tqdm

-from haystack import component, Document, default_to_dict
+from haystack import component, Document, default_to_dict, default_from_dict
+from haystack.utils import Secret, deserialize_secrets_inplace


 @component

@@ -30,7 +31,7 @@ class OpenAIDocumentEmbedder:
     def __init__(
         self,
-        api_key: Optional[str] = None,
+        api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
         model: str = "text-embedding-ada-002",
         api_base_url: Optional[str] = None,
         organization: Optional[str] = None,

@@ -43,8 +44,7 @@ class OpenAIDocumentEmbedder:
     ):
         """
         Create an OpenAIDocumentEmbedder component.
-        :param api_key: The OpenAI API key. It can be explicitly provided or automatically read from the
-            environment variable OPENAI_API_KEY (recommended).
+        :param api_key: The OpenAI API key.
         :param model: The name of the model to use.
         :param api_base_url: The OpenAI API Base url, defaults to None. For more details, see OpenAI [docs](https://platform.openai.com/docs/api-reference/audio).
         :param organization: The Organization ID, defaults to `None`. See

@@ -57,6 +57,7 @@ class OpenAIDocumentEmbedder:
         :param meta_fields_to_embed: List of meta fields that should be embedded along with the Document text.
         :param embedding_separator: Separator used to concatenate the meta fields to the Document text.
         """
+        self.api_key = api_key
         self.model = model
         self.api_base_url = api_base_url
         self.organization = organization

@@ -67,7 +68,7 @@ class OpenAIDocumentEmbedder:
         self.meta_fields_to_embed = meta_fields_to_embed or []
         self.embedding_separator = embedding_separator

-        self.client = OpenAI(api_key=api_key, organization=organization, base_url=api_base_url)
+        self.client = OpenAI(api_key=api_key.resolve_value(), organization=organization, base_url=api_base_url)

     def _get_telemetry_data(self) -> Dict[str, Any]:
         """

@@ -91,8 +92,14 @@ class OpenAIDocumentEmbedder:
             progress_bar=self.progress_bar,
             meta_fields_to_embed=self.meta_fields_to_embed,
             embedding_separator=self.embedding_separator,
+            api_key=self.api_key.to_dict(),
         )

+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "OpenAIDocumentEmbedder":
+        deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"])
+        return default_from_dict(cls, data)
+
     def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
         """
         Prepare the texts to embed by concatenating the Document text with the metadata fields to embed.

View File

@@ -2,7 +2,8 @@ from typing import List, Optional, Dict, Any
 from openai import OpenAI

-from haystack import component, default_to_dict
+from haystack import component, default_to_dict, default_from_dict
+from haystack.utils import Secret, deserialize_secrets_inplace


 @component

@@ -28,7 +29,7 @@ class OpenAITextEmbedder:
     def __init__(
         self,
-        api_key: Optional[str] = None,
+        api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
         model: str = "text-embedding-ada-002",
         api_base_url: Optional[str] = None,
         organization: Optional[str] = None,

@@ -38,8 +39,7 @@ class OpenAITextEmbedder:
         """
         Create an OpenAITextEmbedder component.
-        :param api_key: The OpenAI API key. It can be explicitly provided or automatically read from the
-            environment variable OPENAI_API_KEY (recommended).
+        :param api_key: The OpenAI API key.
         :param model: The name of the OpenAI model to use. For more details on the available models,
             see [OpenAI documentation](https://platform.openai.com/docs/guides/embeddings/embedding-models).
         :param organization: The Organization ID, defaults to `None`. See

@@ -52,8 +52,9 @@ class OpenAITextEmbedder:
         self.organization = organization
         self.prefix = prefix
         self.suffix = suffix
+        self.api_key = api_key

-        self.client = OpenAI(api_key=api_key, organization=organization, base_url=api_base_url)
+        self.client = OpenAI(api_key=api_key.resolve_value(), organization=organization, base_url=api_base_url)

     def _get_telemetry_data(self) -> Dict[str, Any]:
         """

@@ -62,15 +63,20 @@ class OpenAITextEmbedder:
         return {"model": self.model}

     def to_dict(self) -> Dict[str, Any]:
-        """
-        This method overrides the default serializer in order to avoid leaking the `api_key` value passed
-        to the constructor.
-        """
         return default_to_dict(
-            self, model=self.model, organization=self.organization, prefix=self.prefix, suffix=self.suffix
+            self,
+            model=self.model,
+            organization=self.organization,
+            prefix=self.prefix,
+            suffix=self.suffix,
+            api_key=self.api_key.to_dict(),
         )

+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "OpenAITextEmbedder":
+        deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"])
+        return default_from_dict(cls, data)
+
     @component.output_types(embedding=List[float], meta=Dict[str, Any])
     def run(self, text: str):
         """Embed a string."""

View File

@@ -1,9 +1,10 @@
-from typing import List, Optional, Union, Dict, Any
+from typing import List, Optional, Dict, Any

-from haystack import component, Document, default_to_dict
+from haystack import component, Document, default_to_dict, default_from_dict
 from haystack.components.embedders.backends.sentence_transformers_backend import (
     _SentenceTransformersEmbeddingBackendFactory,
 )
+from haystack.utils import Secret, deserialize_secrets_inplace


 @component

@@ -31,7 +32,7 @@ class SentenceTransformersDocumentEmbedder:
         self,
         model: str = "sentence-transformers/all-mpnet-base-v2",
         device: Optional[str] = None,
-        token: Union[bool, str, None] = None,
+        token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
         prefix: str = "",
         suffix: str = "",
         batch_size: int = 32,

@@ -48,8 +49,6 @@ class SentenceTransformersDocumentEmbedder:
         :param device: Device (like 'cuda' / 'cpu') that should be used for computation.
             Defaults to CPU.
         :param token: The API token used to download private models from Hugging Face.
-            If this parameter is set to `True`, then the token generated when running
-            `transformers-cli login` (stored in ~/.huggingface) will be used.
         :param prefix: A string to add to the beginning of each Document text before embedding.
             Can be used to prepend the text with an instruction, as required by some embedding models,
             such as E5 and bge.

@@ -87,7 +86,7 @@ class SentenceTransformersDocumentEmbedder:
             self,
             model=self.model,
             device=self.device,
-            token=self.token if not isinstance(self.token, str) else None,  # don't serialize valid tokens
+            token=self.token.to_dict() if self.token else None,
             prefix=self.prefix,
             suffix=self.suffix,
             batch_size=self.batch_size,

@@ -97,13 +96,18 @@ class SentenceTransformersDocumentEmbedder:
             embedding_separator=self.embedding_separator,
         )

+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "SentenceTransformersDocumentEmbedder":
+        deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
+        return default_from_dict(cls, data)
+
     def warm_up(self):
         """
         Load the embedding backend.
         """
         if not hasattr(self, "embedding_backend"):
             self.embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(
-                model=self.model, device=self.device, use_auth_token=self.token
+                model=self.model, device=self.device, auth_token=self.token
             )

     @component.output_types(documents=List[Document])

View File

@@ -1,9 +1,10 @@
-from typing import List, Optional, Union, Dict, Any
+from typing import List, Optional, Dict, Any

-from haystack import component, default_to_dict
+from haystack import component, default_to_dict, default_from_dict
 from haystack.components.embedders.backends.sentence_transformers_backend import (
     _SentenceTransformersEmbeddingBackendFactory,
 )
+from haystack.utils import Secret, deserialize_secrets_inplace


 @component

@@ -30,7 +31,7 @@ class SentenceTransformersTextEmbedder:
         self,
         model: str = "sentence-transformers/all-mpnet-base-v2",
         device: Optional[str] = None,
-        token: Union[bool, str, None] = None,
+        token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
         prefix: str = "",
         suffix: str = "",
         batch_size: int = 32,

@@ -45,8 +46,6 @@ class SentenceTransformersTextEmbedder:
         :param device: Device (like 'cuda' / 'cpu') that should be used for computation.
             Defaults to CPU.
         :param token: The API token used to download private models from Hugging Face.
-            If this parameter is set to `True`, then the token generated when running
-            `transformers-cli login` (stored in ~/.huggingface) will be used.
         :param prefix: A string to add to the beginning of each Document text before embedding.
             Can be used to prepend the text with an instruction, as required by some embedding models,
             such as E5 and bge.

@@ -80,7 +79,7 @@ class SentenceTransformersTextEmbedder:
             self,
             model=self.model,
             device=self.device,
-            token=self.token if not isinstance(self.token, str) else None,  # don't serialize valid tokens
+            token=self.token.to_dict() if self.token else None,
             prefix=self.prefix,
             suffix=self.suffix,
             batch_size=self.batch_size,

@@ -88,13 +87,18 @@ class SentenceTransformersTextEmbedder:
             normalize_embeddings=self.normalize_embeddings,
         )

+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "SentenceTransformersTextEmbedder":
+        deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
+        return default_from_dict(cls, data)
+
     def warm_up(self):
         """
         Load the embedding backend.
         """
         if not hasattr(self, "embedding_backend"):
             self.embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(
-                model=self.model, device=self.device, use_auth_token=self.token
+                model=self.model, device=self.device, auth_token=self.token
             )

     @component.output_types(embedding=List[float])

View File

@@ -3,12 +3,13 @@ import os
 from typing import Optional, Callable, Dict, Any

 # pylint: disable=import-error
-from openai.lib.azure import AzureADTokenProvider, AzureOpenAI
+from openai.lib.azure import AzureOpenAI

 from haystack import default_to_dict, default_from_dict
 from haystack.components.generators import OpenAIGenerator
 from haystack.components.generators.utils import serialize_callback_handler, deserialize_callback_handler
 from haystack.dataclasses import StreamingChunk
+from haystack.utils import Secret, deserialize_secrets_inplace

 logger = logging.getLogger(__name__)

@@ -28,8 +29,9 @@ class AzureOpenAIGenerator(OpenAIGenerator):
     ```python
     from haystack.components.generators import AzureOpenAIGenerator
+    from haystack.utils import Secret

     client = AzureOpenAIGenerator(azure_endpoint="<Your Azure endpoint e.g. `https://your-company.azure.openai.com/>",
-                                  api_key="<you api key>",
+                                  api_key=Secret.from_token("<your-api-key>"),
                                   azure_deployment="<this a model name, e.g. gpt-35-turbo>")
     response = client.run("What's Natural Language Processing? Be brief.")
     print(response)

@@ -56,9 +58,8 @@ class AzureOpenAIGenerator(OpenAIGenerator):
         azure_endpoint: Optional[str] = None,
         api_version: Optional[str] = "2023-05-15",
         azure_deployment: Optional[str] = "gpt-35-turbo",
-        api_key: Optional[str] = None,
-        azure_ad_token: Optional[str] = None,
-        azure_ad_token_provider: Optional[AzureADTokenProvider] = None,
+        api_key: Optional[Secret] = Secret.from_env_var("AZURE_OPENAI_API_KEY", strict=False),
+        azure_ad_token: Optional[Secret] = Secret.from_env_var("AZURE_OPENAI_AD_TOKEN", strict=False),
         organization: Optional[str] = None,
         streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
         system_prompt: Optional[str] = None,

@@ -70,8 +71,6 @@ class AzureOpenAIGenerator(OpenAIGenerator):
         :param azure_deployment: The deployment of the model, usually the model name.
         :param api_key: The API key to use for authentication.
         :param azure_ad_token: Azure Active Directory token, see https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id
-        :param azure_ad_token_provider: A function that returns an Azure Active Directory token, will be invoked
-            on every request.
         :param organization: The Organization ID, defaults to `None`. See
             [production best practices](https://platform.openai.com/docs/guides/production-best-practices/setting-up-your-organization).
         :param streaming_callback: A callback function that is called when a new token is received from the stream.

@@ -108,6 +107,13 @@ class AzureOpenAIGenerator(OpenAIGenerator):
         if not azure_endpoint:
             raise ValueError("Please provide an Azure endpoint or set the environment variable AZURE_OPENAI_ENDPOINT.")

+        if api_key is None and azure_ad_token is None:
+            raise ValueError("Please provide an API key or an Azure Active Directory token.")
+
+        # The check above makes mypy incorrectly infer that api_key is never None,
+        # which propagates the incorrect type.
+        self.api_key = api_key  # type: ignore
+        self.azure_ad_token = azure_ad_token
         self.generation_kwargs = generation_kwargs or {}
         self.system_prompt = system_prompt
         self.streaming_callback = streaming_callback

@@ -121,9 +127,8 @@ class AzureOpenAIGenerator(OpenAIGenerator):
             api_version=api_version,
             azure_endpoint=azure_endpoint,
             azure_deployment=azure_deployment,
-            api_key=api_key,
-            azure_ad_token=azure_ad_token,
-            azure_ad_token_provider=azure_ad_token_provider,
+            api_key=api_key.resolve_value() if api_key is not None else None,
+            azure_ad_token=azure_ad_token.resolve_value() if azure_ad_token is not None else None,
             organization=organization,
         )

@@ -142,6 +147,8 @@ class AzureOpenAIGenerator(OpenAIGenerator):
             streaming_callback=callback_name,
             generation_kwargs=self.generation_kwargs,
             system_prompt=self.system_prompt,
+            api_key=self.api_key.to_dict() if self.api_key is not None else None,
+            azure_ad_token=self.azure_ad_token.to_dict() if self.azure_ad_token is not None else None,
         )

     @classmethod

@@ -151,6 +158,7 @@ class AzureOpenAIGenerator(OpenAIGenerator):
         :param data: The dictionary representation of this component.
         :return: The deserialized component instance.
         """
+        deserialize_secrets_inplace(data["init_parameters"], keys=["api_key", "azure_ad_token"])
         init_params = data.get("init_parameters", {})
         serialized_callback_handler = init_params.get("streaming_callback")
         if serialized_callback_handler:
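
After this change, a serialized generator carries only secret descriptors, and `from_dict` rehydrates them before handling the streaming callback. A sketch (the exact descriptor layout is my assumption of `Secret.to_dict()`'s output):

```python
from haystack.components.generators import AzureOpenAIGenerator

client = AzureOpenAIGenerator(azure_endpoint="<your-endpoint>")
data = client.to_dict()
# data["init_parameters"]["api_key"] holds something like
# {"type": "env_var", "env_vars": ["AZURE_OPENAI_API_KEY"], "strict": False}
# rather than any key material.
restored = AzureOpenAIGenerator.from_dict(data)
```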

View File

@@ -3,12 +3,13 @@ import os
 from typing import Optional, Callable, Dict, Any

 # pylint: disable=import-error
-from openai.lib.azure import AzureADTokenProvider, AzureOpenAI
+from openai.lib.azure import AzureOpenAI

 from haystack import default_to_dict, default_from_dict
 from haystack.components.generators.chat import OpenAIChatGenerator
 from haystack.components.generators.utils import serialize_callback_handler, deserialize_callback_handler
 from haystack.dataclasses import StreamingChunk
+from haystack.utils import Secret, deserialize_secrets_inplace

 logger = logging.getLogger(__name__)

@@ -28,11 +29,12 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
     ```python
     from haystack.components.generators.chat import AzureOpenAIGenerator
     from haystack.dataclasses import ChatMessage
+    from haystack.utils import Secret

     messages = [ChatMessage.from_user("What's Natural Language Processing?")]

     client = AzureOpenAIGenerator(azure_endpoint="<Your Azure endpoint e.g. `https://your-company.azure.openai.com/>",
-                                  api_key="<you api key>",
+                                  api_key=Secret.from_token("<your-api-key>"),
                                   azure_deployment="<this a model name, e.g. gpt-35-turbo>")
     response = client.run(messages)
     print(response)

@@ -63,9 +65,8 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
         azure_endpoint: Optional[str] = None,
         api_version: Optional[str] = "2023-05-15",
         azure_deployment: Optional[str] = "gpt-35-turbo",
-        api_key: Optional[str] = None,
-        azure_ad_token: Optional[str] = None,
-        azure_ad_token_provider: Optional[AzureADTokenProvider] = None,
+        api_key: Optional[Secret] = Secret.from_env_var("AZURE_OPENAI_API_KEY", strict=False),
+        azure_ad_token: Optional[Secret] = Secret.from_env_var("AZURE_OPENAI_AD_TOKEN", strict=False),
         organization: Optional[str] = None,
         streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
         generation_kwargs: Optional[Dict[str, Any]] = None,

@@ -76,8 +77,6 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
         :param azure_deployment: The deployment of the model, usually the model name.
         :param api_key: The API key to use for authentication.
         :param azure_ad_token: Azure Active Directory token, see https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id
-        :param azure_ad_token_provider: A function that returns an Azure Active Directory token, will be invoked
-            on every request.
         :param organization: The Organization ID, defaults to `None`. See
             [production best practices](https://platform.openai.com/docs/guides/production-best-practices/setting-up-your-organization).
         :param streaming_callback: A callback function that is called when a new token is received from the stream.

@@ -113,6 +112,13 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
         if not azure_endpoint:
             raise ValueError("Please provide an Azure endpoint or set the environment variable AZURE_OPENAI_ENDPOINT.")

+        if api_key is None and azure_ad_token is None:
+            raise ValueError("Please provide an API key or an Azure Active Directory token.")
+
+        # The check above makes mypy incorrectly infer that api_key is never None,
+        # which propagates the incorrect type.
+        self.api_key = api_key  # type: ignore
+        self.azure_ad_token = azure_ad_token
         self.generation_kwargs = generation_kwargs or {}
         self.streaming_callback = streaming_callback
         self.api_version = api_version

@@ -125,9 +131,8 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
             api_version=api_version,
             azure_endpoint=azure_endpoint,
             azure_deployment=azure_deployment,
-            api_key=api_key,
-            azure_ad_token=azure_ad_token,
+            api_key=api_key.resolve_value() if api_key is not None else None,
+            azure_ad_token=azure_ad_token.resolve_value() if azure_ad_token is not None else None,
azure_ad_token_provider=azure_ad_token_provider,
organization=organization, organization=organization,
) )
@ -145,6 +150,8 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
api_version=self.api_version, api_version=self.api_version,
streaming_callback=callback_name, streaming_callback=callback_name,
generation_kwargs=self.generation_kwargs, generation_kwargs=self.generation_kwargs,
api_key=self.api_key.to_dict() if self.api_key is not None else None,
azure_ad_token=self.azure_ad_token.to_dict() if self.azure_ad_token is not None else None,
) )
@classmethod @classmethod
@ -154,6 +161,7 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
:param data: The dictionary representation of this component. :param data: The dictionary representation of this component.
:return: The deserialized component instance. :return: The deserialized component instance.
""" """
deserialize_secrets_inplace(data["init_parameters"], keys=["api_key", "azure_ad_token"])
init_params = data.get("init_parameters", {}) init_params = data.get("init_parameters", {})
serialized_callback_handler = init_params.get("streaming_callback") serialized_callback_handler = init_params.get("streaming_callback")
if serialized_callback_handler: if serialized_callback_handler:
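For reference, a minimal sketch of the non-strict env-var secrets these Azure components now default to: with `strict=False`, an unset variable resolves to `None` instead of raising, which lets the component fall back to the other credential. Variable names are taken from the defaults above; nothing else is assumed.

```python
from haystack.utils import Secret

# Non-strict: resolves to None when AZURE_OPENAI_API_KEY is unset.
api_key = Secret.from_env_var("AZURE_OPENAI_API_KEY", strict=False)
print(api_key.resolve_value())  # -> None if the variable is not set

# Strict (the default): raises instead of returning None.
strict_key = Secret.from_env_var("OPENAI_API_KEY")
# strict_key.resolve_value() would raise:
# ValueError: None of the following authentication environment variables are set: ...
```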
@ -9,6 +9,7 @@ from haystack.components.generators.utils import serialize_callback_handler, des
from haystack.dataclasses import ChatMessage, StreamingChunk from haystack.dataclasses import ChatMessage, StreamingChunk
from haystack.lazy_imports import LazyImport from haystack.lazy_imports import LazyImport
from haystack.utils import ComponentDevice from haystack.utils import ComponentDevice
from haystack.utils import Secret, deserialize_secrets_inplace
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -57,7 +58,7 @@ class HuggingFaceLocalChatGenerator:
model: str = "HuggingFaceH4/zephyr-7b-beta", model: str = "HuggingFaceH4/zephyr-7b-beta",
task: Optional[Literal["text-generation", "text2text-generation"]] = None, task: Optional[Literal["text-generation", "text2text-generation"]] = None,
device: Optional[ComponentDevice] = None, device: Optional[ComponentDevice] = None,
token: Optional[Union[str, bool]] = None, token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
chat_template: Optional[str] = None, chat_template: Optional[str] = None,
generation_kwargs: Optional[Dict[str, Any]] = None, generation_kwargs: Optional[Dict[str, Any]] = None,
huggingface_pipeline_kwargs: Optional[Dict[str, Any]] = None, huggingface_pipeline_kwargs: Optional[Dict[str, Any]] = None,
@ -80,7 +81,6 @@ class HuggingFaceLocalChatGenerator:
:param device: The device on which the model is loaded. If `None`, the default device is automatically :param device: The device on which the model is loaded. If `None`, the default device is automatically
selected. If a device/device map is specified in `huggingface_pipeline_kwargs`, it overrides this parameter. selected. If a device/device map is specified in `huggingface_pipeline_kwargs`, it overrides this parameter.
:param token: The token to use as HTTP bearer authorization for remote files. :param token: The token to use as HTTP bearer authorization for remote files.
If True, will use the token generated when running huggingface-cli login (stored in ~/.huggingface).
If the token is also specified in the `huggingface_pipeline_kwargs`, this parameter will be ignored. If the token is also specified in the `huggingface_pipeline_kwargs`, this parameter will be ignored.
:param chat_template: This optional parameter allows you to specify a Jinja template for formatting chat :param chat_template: This optional parameter allows you to specify a Jinja template for formatting chat
messages. While high-quality and well-supported chat models typically include their own chat templates messages. While high-quality and well-supported chat models typically include their own chat templates
@ -113,6 +113,9 @@ class HuggingFaceLocalChatGenerator:
huggingface_pipeline_kwargs = huggingface_pipeline_kwargs or {} huggingface_pipeline_kwargs = huggingface_pipeline_kwargs or {}
generation_kwargs = generation_kwargs or {} generation_kwargs = generation_kwargs or {}
self.token = token
token = token.resolve_value() if token else None
# check if the huggingface_pipeline_kwargs contain the essential parameters # check if the huggingface_pipeline_kwargs contain the essential parameters
# otherwise, populate them with values from other init parameters # otherwise, populate them with values from other init parameters
huggingface_pipeline_kwargs.setdefault("model", model) huggingface_pipeline_kwargs.setdefault("model", model)
@ -178,12 +181,11 @@ class HuggingFaceLocalChatGenerator:
huggingface_pipeline_kwargs=self.huggingface_pipeline_kwargs, huggingface_pipeline_kwargs=self.huggingface_pipeline_kwargs,
generation_kwargs=self.generation_kwargs, generation_kwargs=self.generation_kwargs,
streaming_callback=callback_name, streaming_callback=callback_name,
token=self.token.to_dict() if self.token else None,
) )
huggingface_pipeline_kwargs = serialization_dict["init_parameters"]["huggingface_pipeline_kwargs"] huggingface_pipeline_kwargs = serialization_dict["init_parameters"]["huggingface_pipeline_kwargs"]
# we don't want to serialize valid tokens huggingface_pipeline_kwargs.pop("token", None)
if isinstance(huggingface_pipeline_kwargs["token"], str):
serialization_dict["init_parameters"]["huggingface_pipeline_kwargs"].pop("token")
serialize_hf_model_kwargs(huggingface_pipeline_kwargs) serialize_hf_model_kwargs(huggingface_pipeline_kwargs)
return serialization_dict return serialization_dict
@ -194,7 +196,7 @@ class HuggingFaceLocalChatGenerator:
Deserialize this component from a dictionary. Deserialize this component from a dictionary.
""" """
torch_and_transformers_import.check() # leave this, cls method torch_and_transformers_import.check() # leave this, cls method
deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
init_params = data.get("init_parameters", {}) init_params = data.get("init_parameters", {})
serialized_callback_handler = init_params.get("streaming_callback") serialized_callback_handler = init_params.get("streaming_callback")
if serialized_callback_handler: if serialized_callback_handler:
@ -8,6 +8,7 @@ from haystack.components.generators.utils import serialize_callback_handler, des
from haystack.dataclasses import ChatMessage, StreamingChunk from haystack.dataclasses import ChatMessage, StreamingChunk
from haystack.components.generators.hf_utils import check_valid_model, check_generation_params from haystack.components.generators.hf_utils import check_valid_model, check_generation_params
from haystack.lazy_imports import LazyImport from haystack.lazy_imports import LazyImport
from haystack.utils import Secret, deserialize_secrets_inplace
with LazyImport(message="Run 'pip install transformers'") as transformers_import: with LazyImport(message="Run 'pip install transformers'") as transformers_import:
from huggingface_hub import InferenceClient from huggingface_hub import InferenceClient
@ -28,12 +29,13 @@ class HuggingFaceTGIChatGenerator:
```python ```python
from haystack.components.generators.chat import HuggingFaceTGIChatGenerator from haystack.components.generators.chat import HuggingFaceTGIChatGenerator
from haystack.dataclasses import ChatMessage from haystack.dataclasses import ChatMessage
from haystack.utils import Secret
messages = [ChatMessage.from_system("\\nYou are a helpful, respectful and honest assistant"), messages = [ChatMessage.from_system("\\nYou are a helpful, respectful and honest assistant"),
ChatMessage.from_user("What's Natural Language Processing?")] ChatMessage.from_user("What's Natural Language Processing?")]
client = HuggingFaceTGIChatGenerator(model="meta-llama/Llama-2-70b-chat-hf", token="<your-token>") client = HuggingFaceTGIChatGenerator(model="meta-llama/Llama-2-70b-chat-hf", token=Secret.from_token("<your-api-key>"))
client.warm_up() client.warm_up()
response = client.run(messages, generation_kwargs={"max_new_tokens": 120}) response = client.run(messages, generation_kwargs={"max_new_tokens": 120})
print(response) print(response)
@ -51,7 +53,7 @@ class HuggingFaceTGIChatGenerator:
client = HuggingFaceTGIChatGenerator(model="meta-llama/Llama-2-70b-chat-hf", client = HuggingFaceTGIChatGenerator(model="meta-llama/Llama-2-70b-chat-hf",
url="<your-tgi-endpoint-url>", url="<your-tgi-endpoint-url>",
token="<your-token>") token=Secret.from_token("<your-api-key>"))
client.warm_up() client.warm_up()
response = client.run(messages, generation_kwargs={"max_new_tokens": 120}) response = client.run(messages, generation_kwargs={"max_new_tokens": 120})
print(response) print(response)
@ -85,7 +87,7 @@ class HuggingFaceTGIChatGenerator:
self, self,
model: str = "meta-llama/Llama-2-13b-chat-hf", model: str = "meta-llama/Llama-2-13b-chat-hf",
url: Optional[str] = None, url: Optional[str] = None,
token: Optional[str] = None, token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
chat_template: Optional[str] = None, chat_template: Optional[str] = None,
generation_kwargs: Optional[Dict[str, Any]] = None, generation_kwargs: Optional[Dict[str, Any]] = None,
stop_words: Optional[List[str]] = None, stop_words: Optional[List[str]] = None,
@ -131,7 +133,7 @@ class HuggingFaceTGIChatGenerator:
self.chat_template = chat_template self.chat_template = chat_template
self.token = token self.token = token
self.generation_kwargs = generation_kwargs self.generation_kwargs = generation_kwargs
self.client = InferenceClient(url or model, token=token) self.client = InferenceClient(url or model, token=token.resolve_value() if token else None)
self.streaming_callback = streaming_callback self.streaming_callback = streaming_callback
self.tokenizer = None self.tokenizer = None
@ -139,7 +141,9 @@ class HuggingFaceTGIChatGenerator:
""" """
Load the tokenizer. Load the tokenizer.
""" """
self.tokenizer = AutoTokenizer.from_pretrained(self.model, token=self.token) self.tokenizer = AutoTokenizer.from_pretrained(
self.model, token=self.token.resolve_value() if self.token else None
)
# mypy can't infer that chat_template attribute exists on the object returned by AutoTokenizer.from_pretrained # mypy can't infer that chat_template attribute exists on the object returned by AutoTokenizer.from_pretrained
chat_template = getattr(self.tokenizer, "chat_template", None) chat_template = getattr(self.tokenizer, "chat_template", None)
if not chat_template and not self.chat_template: if not chat_template and not self.chat_template:
@ -162,7 +166,7 @@ class HuggingFaceTGIChatGenerator:
model=self.model, model=self.model,
url=self.url, url=self.url,
chat_template=self.chat_template, chat_template=self.chat_template,
token=self.token if not isinstance(self.token, str) else None, # don't serialize valid tokens token=self.token.to_dict() if self.token else None,
generation_kwargs=self.generation_kwargs, generation_kwargs=self.generation_kwargs,
streaming_callback=callback_name, streaming_callback=callback_name,
) )
@ -172,6 +176,7 @@ class HuggingFaceTGIChatGenerator:
""" """
Deserialize this component from a dictionary. Deserialize this component from a dictionary.
""" """
deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
init_params = data.get("init_parameters", {}) init_params = data.get("init_parameters", {})
serialized_callback_handler = init_params.get("streaming_callback") serialized_callback_handler = init_params.get("streaming_callback")
if serialized_callback_handler: if serialized_callback_handler:
@ -13,6 +13,7 @@ from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice
from haystack import component, default_from_dict, default_to_dict from haystack import component, default_from_dict, default_to_dict
from haystack.components.generators.utils import serialize_callback_handler, deserialize_callback_handler from haystack.components.generators.utils import serialize_callback_handler, deserialize_callback_handler
from haystack.dataclasses import StreamingChunk, ChatMessage from haystack.dataclasses import StreamingChunk, ChatMessage
from haystack.utils import Secret, deserialize_secrets_inplace
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -62,7 +63,7 @@ class OpenAIChatGenerator:
def __init__( def __init__(
self, self,
api_key: Optional[str] = None, api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
model: str = "gpt-3.5-turbo", model: str = "gpt-3.5-turbo",
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
api_base_url: Optional[str] = None, api_base_url: Optional[str] = None,
@ -73,8 +74,7 @@ class OpenAIChatGenerator:
Creates an instance of OpenAIChatGenerator. Unless specified otherwise in the `model`, this is for OpenAI's Creates an instance of OpenAIChatGenerator. Unless specified otherwise in the `model`, this is for OpenAI's
GPT-3.5 model. GPT-3.5 model.
:param api_key: The OpenAI API key. It can be explicitly provided or automatically read from the :param api_key: The OpenAI API key.
environment variable OPENAI_API_KEY (recommended).
:param model: The name of the model to use. :param model: The name of the model to use.
:param streaming_callback: A callback function that is called when a new token is received from the stream. :param streaming_callback: A callback function that is called when a new token is received from the stream.
The callback function accepts StreamingChunk as an argument. The callback function accepts StreamingChunk as an argument.
@ -101,12 +101,13 @@ class OpenAIChatGenerator:
- `logit_bias`: Add a logit bias to specific tokens. The keys of the dictionary are tokens, and the - `logit_bias`: Add a logit bias to specific tokens. The keys of the dictionary are tokens, and the
values are the bias to add to that token. values are the bias to add to that token.
""" """
self.api_key = api_key
self.model = model self.model = model
self.generation_kwargs = generation_kwargs or {} self.generation_kwargs = generation_kwargs or {}
self.streaming_callback = streaming_callback self.streaming_callback = streaming_callback
self.api_base_url = api_base_url self.api_base_url = api_base_url
self.organization = organization self.organization = organization
self.client = OpenAI(api_key=api_key, organization=organization, base_url=api_base_url) self.client = OpenAI(api_key=api_key.resolve_value(), organization=organization, base_url=api_base_url)
def _get_telemetry_data(self) -> Dict[str, Any]: def _get_telemetry_data(self) -> Dict[str, Any]:
""" """
@ -127,6 +128,7 @@ class OpenAIChatGenerator:
api_base_url=self.api_base_url, api_base_url=self.api_base_url,
organization=self.organization, organization=self.organization,
generation_kwargs=self.generation_kwargs, generation_kwargs=self.generation_kwargs,
api_key=self.api_key.to_dict(),
) )
@classmethod @classmethod
@ -136,6 +138,7 @@ class OpenAIChatGenerator:
:param data: The dictionary representation of this component. :param data: The dictionary representation of this component.
:return: The deserialized component instance. :return: The deserialized component instance.
""" """
deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"])
init_params = data.get("init_parameters", {}) init_params = data.get("init_parameters", {})
serialized_callback_handler = init_params.get("streaming_callback") serialized_callback_handler = init_params.get("streaming_callback")
if serialized_callback_handler: if serialized_callback_handler:
@ -334,7 +337,7 @@ class OpenAIChatGenerator:
class GPTChatGenerator(OpenAIChatGenerator): class GPTChatGenerator(OpenAIChatGenerator):
def __init__( def __init__(
self, self,
api_key: Optional[str] = None, api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
model: str = "gpt-3.5-turbo", model: str = "gpt-3.5-turbo",
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
api_base_url: Optional[str] = None, api_base_url: Optional[str] = None,
@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional, Union, Callable
from haystack.dataclasses import StreamingChunk from haystack.dataclasses import StreamingChunk
from haystack.lazy_imports import LazyImport from haystack.lazy_imports import LazyImport
from haystack.utils import Secret
with LazyImport(message="Run 'pip install transformers'") as transformers_import: with LazyImport(message="Run 'pip install transformers'") as transformers_import:
from huggingface_hub import InferenceClient, HfApi from huggingface_hub import InferenceClient, HfApi
@ -36,20 +37,20 @@ def check_generation_params(kwargs: Optional[Dict[str, Any]], additional_accepte
) )
def check_valid_model(model_id: str, token: Optional[str]) -> None: def check_valid_model(model_id: str, token: Optional[Secret]) -> None:
""" """
Check if the provided model ID corresponds to a valid model on HuggingFace Hub. Check if the provided model ID corresponds to a valid model on HuggingFace Hub.
Also check if the model is a text generation model. Also check if the model is a text generation model.
:param model_id: A string representing the HuggingFace model ID. :param model_id: A string representing the HuggingFace model ID.
:param token: An optional string representing the authentication token. :param token: An optional authentication token.
:raises ValueError: If the model is not found or is not a text generation model. :raises ValueError: If the model is not found or is not a text generation model.
""" """
transformers_import.check() transformers_import.check()
api = HfApi() api = HfApi()
try: try:
model_info = api.model_info(model_id, token=token) model_info = api.model_info(model_id, token=token.resolve_value() if token else None)
except RepositoryNotFoundError as e: except RepositoryNotFoundError as e:
raise ValueError( raise ValueError(
f"Model {model_id} not found on HuggingFace Hub. Please provide a valid HuggingFace model_id." f"Model {model_id} not found on HuggingFace Hub. Please provide a valid HuggingFace model_id."
@ -1,11 +1,12 @@
import logging import logging
from typing import Any, Dict, List, Literal, Optional, Union from typing import Any, Dict, List, Literal, Optional
from haystack import component, default_from_dict, default_to_dict from haystack import component, default_from_dict, default_to_dict
from haystack.lazy_imports import LazyImport from haystack.lazy_imports import LazyImport
from haystack.utils import ComponentDevice from haystack.utils import ComponentDevice
from haystack.utils.hf import deserialize_hf_model_kwargs, serialize_hf_model_kwargs from haystack.utils.hf import deserialize_hf_model_kwargs, serialize_hf_model_kwargs
from haystack.utils import Secret, deserialize_secrets_inplace
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -44,7 +45,7 @@ class HuggingFaceLocalGenerator:
model: str = "google/flan-t5-base", model: str = "google/flan-t5-base",
task: Optional[Literal["text-generation", "text2text-generation"]] = None, task: Optional[Literal["text-generation", "text2text-generation"]] = None,
device: Optional[ComponentDevice] = None, device: Optional[ComponentDevice] = None,
token: Optional[Union[str, bool]] = None, token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
generation_kwargs: Optional[Dict[str, Any]] = None, generation_kwargs: Optional[Dict[str, Any]] = None,
huggingface_pipeline_kwargs: Optional[Dict[str, Any]] = None, huggingface_pipeline_kwargs: Optional[Dict[str, Any]] = None,
stop_words: Optional[List[str]] = None, stop_words: Optional[List[str]] = None,
@ -63,7 +64,6 @@ class HuggingFaceLocalGenerator:
:param device: The device on which the model is loaded. If `None`, the default device is automatically :param device: The device on which the model is loaded. If `None`, the default device is automatically
selected. If a device/device map is specified in `huggingface_pipeline_kwargs`, it overrides this parameter. selected. If a device/device map is specified in `huggingface_pipeline_kwargs`, it overrides this parameter.
:param token: The token to use as HTTP bearer authorization for remote files. :param token: The token to use as HTTP bearer authorization for remote files.
If True, will use the token generated when running huggingface-cli login (stored in ~/.huggingface).
If the token is also specified in the `huggingface_pipeline_kwargs`, this parameter will be ignored. If the token is also specified in the `huggingface_pipeline_kwargs`, this parameter will be ignored.
:param generation_kwargs: A dictionary containing keyword arguments to customize text generation. :param generation_kwargs: A dictionary containing keyword arguments to customize text generation.
Some examples: `max_length`, `max_new_tokens`, `temperature`, `top_k`, `top_p`,... Some examples: `max_length`, `max_new_tokens`, `temperature`, `top_k`, `top_p`,...
@ -89,6 +89,9 @@ class HuggingFaceLocalGenerator:
huggingface_pipeline_kwargs = huggingface_pipeline_kwargs or {} huggingface_pipeline_kwargs = huggingface_pipeline_kwargs or {}
generation_kwargs = generation_kwargs or {} generation_kwargs = generation_kwargs or {}
self.token = token
token = token.resolve_value() if token else None
# check if the huggingface_pipeline_kwargs contain the essential parameters # check if the huggingface_pipeline_kwargs contain the essential parameters
# otherwise, populate them with values from other init parameters # otherwise, populate them with values from other init parameters
huggingface_pipeline_kwargs.setdefault("model", model) huggingface_pipeline_kwargs.setdefault("model", model)
@ -156,12 +159,11 @@ class HuggingFaceLocalGenerator:
huggingface_pipeline_kwargs=self.huggingface_pipeline_kwargs, huggingface_pipeline_kwargs=self.huggingface_pipeline_kwargs,
generation_kwargs=self.generation_kwargs, generation_kwargs=self.generation_kwargs,
stop_words=self.stop_words, stop_words=self.stop_words,
token=self.token.to_dict() if self.token else None,
) )
huggingface_pipeline_kwargs = serialization_dict["init_parameters"]["huggingface_pipeline_kwargs"] huggingface_pipeline_kwargs = serialization_dict["init_parameters"]["huggingface_pipeline_kwargs"]
# we don't want to serialize valid tokens huggingface_pipeline_kwargs.pop("token", None)
if isinstance(huggingface_pipeline_kwargs["token"], str):
serialization_dict["init_parameters"]["huggingface_pipeline_kwargs"].pop("token")
serialize_hf_model_kwargs(huggingface_pipeline_kwargs) serialize_hf_model_kwargs(huggingface_pipeline_kwargs)
return serialization_dict return serialization_dict
@ -171,6 +173,7 @@ class HuggingFaceLocalGenerator:
""" """
Deserialize this component from a dictionary. Deserialize this component from a dictionary.
""" """
deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
deserialize_hf_model_kwargs(data["init_parameters"]["huggingface_pipeline_kwargs"]) deserialize_hf_model_kwargs(data["init_parameters"]["huggingface_pipeline_kwargs"])
return default_from_dict(cls, data) return default_from_dict(cls, data)
@ -8,6 +8,7 @@ from haystack.components.generators.utils import serialize_callback_handler, des
from haystack.dataclasses import StreamingChunk from haystack.dataclasses import StreamingChunk
from haystack.components.generators.hf_utils import check_generation_params, check_valid_model from haystack.components.generators.hf_utils import check_generation_params, check_valid_model
from haystack.lazy_imports import LazyImport from haystack.lazy_imports import LazyImport
from haystack.utils import Secret, deserialize_secrets_inplace
with LazyImport(message="Run 'pip install transformers'") as transformers_import: with LazyImport(message="Run 'pip install transformers'") as transformers_import:
from huggingface_hub import InferenceClient from huggingface_hub import InferenceClient
@ -29,7 +30,9 @@ class HuggingFaceTGIGenerator:
```python ```python
from haystack.components.generators import HuggingFaceTGIGenerator from haystack.components.generators import HuggingFaceTGIGenerator
client = HuggingFaceTGIGenerator(model="mistralai/Mistral-7B-v0.1", token="<your-token>") from haystack.utils import Secret
client = HuggingFaceTGIGenerator(model="mistralai/Mistral-7B-v0.1", token=Secret.from_token("<your-api-key>"))
client.warm_up() client.warm_up()
response = client.run("What's Natural Language Processing?", max_new_tokens=120) response = client.run("What's Natural Language Processing?", max_new_tokens=120)
print(response) print(response)
@ -42,7 +45,7 @@ class HuggingFaceTGIGenerator:
from haystack.components.generators import HuggingFaceTGIGenerator from haystack.components.generators import HuggingFaceTGIGenerator
client = HuggingFaceTGIGenerator(model="mistralai/Mistral-7B-v0.1", client = HuggingFaceTGIGenerator(model="mistralai/Mistral-7B-v0.1",
url="<your-tgi-endpoint-url>", url="<your-tgi-endpoint-url>",
token="<your-token>") token=Secret.from_token("<your-api-key>"))
client.warm_up() client.warm_up()
response = client.run("What's Natural Language Processing?") response = client.run("What's Natural Language Processing?")
print(response) print(response)
@ -74,7 +77,7 @@ class HuggingFaceTGIGenerator:
self, self,
model: str = "mistralai/Mistral-7B-v0.1", model: str = "mistralai/Mistral-7B-v0.1",
url: Optional[str] = None, url: Optional[str] = None,
token: Optional[str] = None, token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
generation_kwargs: Optional[Dict[str, Any]] = None, generation_kwargs: Optional[Dict[str, Any]] = None,
stop_words: Optional[List[str]] = None, stop_words: Optional[List[str]] = None,
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
@ -113,7 +116,7 @@ class HuggingFaceTGIGenerator:
self.url = url self.url = url
self.token = token self.token = token
self.generation_kwargs = generation_kwargs self.generation_kwargs = generation_kwargs
self.client = InferenceClient(url or model, token=token) self.client = InferenceClient(url or model, token=token.resolve_value() if token else None)
self.streaming_callback = streaming_callback self.streaming_callback = streaming_callback
self.tokenizer = None self.tokenizer = None
@ -121,7 +124,9 @@ class HuggingFaceTGIGenerator:
""" """
Load the tokenizer Load the tokenizer
""" """
self.tokenizer = AutoTokenizer.from_pretrained(self.model, token=self.token) self.tokenizer = AutoTokenizer.from_pretrained(
self.model, token=self.token.resolve_value() if self.token else None
)
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
""" """
@ -134,7 +139,7 @@ class HuggingFaceTGIGenerator:
self, self,
model=self.model, model=self.model,
url=self.url, url=self.url,
token=self.token if not isinstance(self.token, str) else None, # don't serialize valid tokens token=self.token.to_dict() if self.token else None,
generation_kwargs=self.generation_kwargs, generation_kwargs=self.generation_kwargs,
streaming_callback=callback_name, streaming_callback=callback_name,
) )
@ -144,6 +149,7 @@ class HuggingFaceTGIGenerator:
""" """
Deserialize this component from a dictionary. Deserialize this component from a dictionary.
""" """
deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
init_params = data.get("init_parameters", {}) init_params = data.get("init_parameters", {})
serialized_callback_handler = init_params.get("streaming_callback") serialized_callback_handler = init_params.get("streaming_callback")
if serialized_callback_handler: if serialized_callback_handler:
@ -9,6 +9,7 @@ from openai.types.chat import ChatCompletionChunk, ChatCompletion
from haystack import component, default_from_dict, default_to_dict from haystack import component, default_from_dict, default_to_dict
from haystack.components.generators.utils import serialize_callback_handler, deserialize_callback_handler from haystack.components.generators.utils import serialize_callback_handler, deserialize_callback_handler
from haystack.dataclasses import StreamingChunk, ChatMessage from haystack.dataclasses import StreamingChunk, ChatMessage
from haystack.utils import Secret, deserialize_secrets_inplace
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -50,7 +51,7 @@ class OpenAIGenerator:
def __init__( def __init__(
self, self,
api_key: Optional[str] = None, api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
model: str = "gpt-3.5-turbo", model: str = "gpt-3.5-turbo",
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
api_base_url: Optional[str] = None, api_base_url: Optional[str] = None,
@ -62,8 +63,7 @@ class OpenAIGenerator:
Creates an instance of OpenAIGenerator. Unless specified otherwise in the `model`, this is for OpenAI's Creates an instance of OpenAIGenerator. Unless specified otherwise in the `model`, this is for OpenAI's
GPT-3.5 model. GPT-3.5 model.
:param api_key: The OpenAI API key. It can be explicitly provided or automatically read from the :param api_key: The OpenAI API key.
environment variable OPENAI_API_KEY (recommended).
:param model: The name of the model to use. :param model: The name of the model to use.
:param streaming_callback: A callback function that is called when a new token is received from the stream. :param streaming_callback: A callback function that is called when a new token is received from the stream.
The callback function accepts StreamingChunk as an argument. The callback function accepts StreamingChunk as an argument.
@ -92,6 +92,7 @@ class OpenAIGenerator:
- `logit_bias`: Add a logit bias to specific tokens. The keys of the dictionary are tokens, and the - `logit_bias`: Add a logit bias to specific tokens. The keys of the dictionary are tokens, and the
values are the bias to add to that token. values are the bias to add to that token.
""" """
self.api_key = api_key
self.model = model self.model = model
self.generation_kwargs = generation_kwargs or {} self.generation_kwargs = generation_kwargs or {}
self.system_prompt = system_prompt self.system_prompt = system_prompt
@ -99,7 +100,7 @@ class OpenAIGenerator:
self.api_base_url = api_base_url self.api_base_url = api_base_url
self.organization = organization self.organization = organization
self.client = OpenAI(api_key=api_key, organization=organization, base_url=api_base_url) self.client = OpenAI(api_key=api_key.resolve_value(), organization=organization, base_url=api_base_url)
def _get_telemetry_data(self) -> Dict[str, Any]: def _get_telemetry_data(self) -> Dict[str, Any]:
""" """
@ -120,6 +121,7 @@ class OpenAIGenerator:
api_base_url=self.api_base_url, api_base_url=self.api_base_url,
generation_kwargs=self.generation_kwargs, generation_kwargs=self.generation_kwargs,
system_prompt=self.system_prompt, system_prompt=self.system_prompt,
api_key=self.api_key.to_dict(),
) )
@classmethod @classmethod
@ -129,6 +131,7 @@ class OpenAIGenerator:
:param data: The dictionary representation of this component. :param data: The dictionary representation of this component.
:return: The deserialized component instance. :return: The deserialized component instance.
""" """
deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"])
init_params = data.get("init_parameters", {}) init_params = data.get("init_parameters", {})
serialized_callback_handler = init_params.get("streaming_callback") serialized_callback_handler = init_params.get("streaming_callback")
if serialized_callback_handler: if serialized_callback_handler:
@ -279,7 +282,7 @@ class OpenAIGenerator:
class GPTGenerator(OpenAIGenerator): class GPTGenerator(OpenAIGenerator):
def __init__( def __init__(
self, self,
api_key: Optional[str] = None, api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
model: str = "gpt-3.5-turbo", model: str = "gpt-3.5-turbo",
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
api_base_url: Optional[str] = None, api_base_url: Optional[str] = None,
@ -6,6 +6,7 @@ from haystack import ComponentError, Document, component, default_from_dict, def
from haystack.lazy_imports import LazyImport from haystack.lazy_imports import LazyImport
from haystack.utils import ComponentDevice, DeviceMap from haystack.utils import ComponentDevice, DeviceMap
from haystack.utils.hf import deserialize_hf_model_kwargs, serialize_hf_model_kwargs, resolve_hf_device_map from haystack.utils.hf import deserialize_hf_model_kwargs, serialize_hf_model_kwargs, resolve_hf_device_map
from haystack.utils import Secret, deserialize_secrets_inplace
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -40,7 +41,7 @@ class TransformersSimilarityRanker:
self, self,
model: Union[str, Path] = "cross-encoder/ms-marco-MiniLM-L-6-v2", model: Union[str, Path] = "cross-encoder/ms-marco-MiniLM-L-6-v2",
device: Optional[ComponentDevice] = None, device: Optional[ComponentDevice] = None,
token: Union[bool, str, None] = None, token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
top_k: int = 10, top_k: int = 10,
query_prefix: str = "", query_prefix: str = "",
document_prefix: str = "", document_prefix: str = "",
@ -119,9 +120,11 @@ class TransformersSimilarityRanker:
""" """
if self.model is None: if self.model is None:
self.model = AutoModelForSequenceClassification.from_pretrained( self.model = AutoModelForSequenceClassification.from_pretrained(
self.model_name_or_path, token=self.token, **self.model_kwargs self.model_name_or_path, token=self.token.resolve_value() if self.token else None, **self.model_kwargs
)
self.tokenizer = AutoTokenizer.from_pretrained(
self.model_name_or_path, token=self.token.resolve_value() if self.token else None
) )
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, token=self.token)
self.device = ComponentDevice.from_multiple(device_map=DeviceMap.from_hf(self.model.hf_device_map)) self.device = ComponentDevice.from_multiple(device_map=DeviceMap.from_hf(self.model.hf_device_map))
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
@ -132,7 +135,7 @@ class TransformersSimilarityRanker:
self, self,
device=None, device=None,
model=self.model_name_or_path, model=self.model_name_or_path,
token=self.token if not isinstance(self.token, str) else None, # don't serialize valid tokens token=self.token.to_dict() if self.token else None,
top_k=self.top_k, top_k=self.top_k,
query_prefix=self.query_prefix, query_prefix=self.query_prefix,
document_prefix=self.document_prefix, document_prefix=self.document_prefix,
@ -152,6 +155,7 @@ class TransformersSimilarityRanker:
""" """
Deserialize this component from a dictionary. Deserialize this component from a dictionary.
""" """
deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
init_params = data["init_parameters"] init_params = data["init_parameters"]
if init_params["device"] is not None: if init_params["device"] is not None:
init_params["device"] = ComponentDevice.from_dict(init_params["device"]) init_params["device"] = ComponentDevice.from_dict(init_params["device"])
@ -1,11 +1,11 @@
import json import json
import logging import logging
from typing import Dict, List, Optional, Any from typing import Dict, List, Optional, Any
import os
import requests import requests
from haystack import Document, component, default_to_dict, ComponentError from haystack import Document, component, default_to_dict, ComponentError, default_from_dict
from haystack.utils import Secret, deserialize_secrets_inplace
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -27,43 +27,48 @@ class SearchApiWebSearch:
def __init__( def __init__(
self, self,
api_key: Optional[str] = None, api_key: Secret = Secret.from_env_var("SEARCHAPI_API_KEY"),
top_k: Optional[int] = 10, top_k: Optional[int] = 10,
allowed_domains: Optional[List[str]] = None, allowed_domains: Optional[List[str]] = None,
search_params: Optional[Dict[str, Any]] = None, search_params: Optional[Dict[str, Any]] = None,
): ):
""" """
:param api_key: API key for the SearchApi API. It can be :param api_key: API key for the SearchApi API.
explicitly provided or automatically read from the
environment variable SEARCHAPI_API_KEY (recommended).
:param top_k: Number of documents to return. :param top_k: Number of documents to return.
:param allowed_domains: List of domains to limit the search to. :param allowed_domains: List of domains to limit the search to.
:param search_params: Additional parameters passed to the SearchApi API. :param search_params: Additional parameters passed to the SearchApi API.
For example, you can set 'num' to 100 to increase the number of search results. For example, you can set 'num' to 100 to increase the number of search results.
See the [SearchApi website](https://www.searchapi.io/) for more details. See the [SearchApi website](https://www.searchapi.io/) for more details.
""" """
api_key = api_key or os.environ.get("SEARCHAPI_API_KEY")
# we check whether api_key is None or an empty string
if not api_key:
msg = (
"SearchApiWebSearch expects an API key. "
"Set the SEARCHAPI_API_KEY environment variable (recommended) or pass it explicitly."
)
raise ValueError(msg)
self.api_key = api_key self.api_key = api_key
self.top_k = top_k self.top_k = top_k
self.allowed_domains = allowed_domains self.allowed_domains = allowed_domains
self.search_params = search_params or {} self.search_params = search_params or {}
# Ensure that the API key is resolved.
_ = self.api_key.resolve_value()
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
""" """
Serialize this component to a dictionary. Serialize this component to a dictionary.
""" """
return default_to_dict( return default_to_dict(
self, top_k=self.top_k, allowed_domains=self.allowed_domains, search_params=self.search_params self,
top_k=self.top_k,
allowed_domains=self.allowed_domains,
search_params=self.search_params,
api_key=self.api_key.to_dict(),
) )
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "SearchApiWebSearch":
"""
Deserialize this component from a dictionary.
"""
deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"])
return default_from_dict(cls, data)
@component.output_types(documents=List[Document], links=List[str]) @component.output_types(documents=List[Document], links=List[str])
def run(self, query: str): def run(self, query: str):
""" """
@ -74,7 +79,7 @@ class SearchApiWebSearch:
query_prepend = "OR ".join(f"site:{domain} " for domain in self.allowed_domains) if self.allowed_domains else "" query_prepend = "OR ".join(f"site:{domain} " for domain in self.allowed_domains) if self.allowed_domains else ""
payload = json.dumps({"q": query_prepend + " " + query, **self.search_params}) payload = json.dumps({"q": query_prepend + " " + query, **self.search_params})
headers = {"Authorization": f"Bearer {self.api_key}", "X-SearchApi-Source": "Haystack"} headers = {"Authorization": f"Bearer {self.api_key.resolve_value()}", "X-SearchApi-Source": "Haystack"}
try: try:
response = requests.get(SEARCHAPI_BASE_URL, headers=headers, params=payload, timeout=90) response = requests.get(SEARCHAPI_BASE_URL, headers=headers, params=payload, timeout=90)
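Note that the eager `resolve_value()` call in `__init__` replaces the old manual check: constructing the component with a strict, unresolvable secret fails immediately. A minimal sketch, assuming `SEARCHAPI_API_KEY` is unset (import path assumed from the package layout):

```python
from haystack.components.websearch import SearchApiWebSearch

try:
    SearchApiWebSearch()
except ValueError as err:
    # Raised by the strict env-var secret during __init__.
    print(err)
```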
@ -1,11 +1,11 @@
import json import json
import logging import logging
from typing import Dict, List, Optional, Any from typing import Dict, List, Optional, Any
import os
import requests import requests
from haystack import Document, component, default_to_dict, ComponentError from haystack import Document, component, default_to_dict, ComponentError, default_from_dict
from haystack.utils import Secret, deserialize_secrets_inplace
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -27,43 +27,47 @@ class SerperDevWebSearch:
def __init__( def __init__(
self, self,
api_key: Optional[str] = None, api_key: Secret = Secret.from_env_var("SERPERDEV_API_KEY"),
top_k: Optional[int] = 10, top_k: Optional[int] = 10,
allowed_domains: Optional[List[str]] = None, allowed_domains: Optional[List[str]] = None,
search_params: Optional[Dict[str, Any]] = None, search_params: Optional[Dict[str, Any]] = None,
): ):
""" """
:param api_key: API key for the SerperDev API. It can be :param api_key: API key for the SerperDev API.
explicitly provided or automatically read from the
environment variable SERPERDEV_API_KEY (recommended).
:param top_k: Number of documents to return. :param top_k: Number of documents to return.
:param allowed_domains: List of domains to limit the search to. :param allowed_domains: List of domains to limit the search to.
:param search_params: Additional parameters passed to the SerperDev API. :param search_params: Additional parameters passed to the SerperDev API.
For example, you can set 'num' to 20 to increase the number of search results. For example, you can set 'num' to 20 to increase the number of search results.
See the [Serper Dev website](https://serper.dev/) for more details. See the [Serper Dev website](https://serper.dev/) for more details.
""" """
api_key = api_key or os.environ.get("SERPERDEV_API_KEY")
# we check whether api_key is None or an empty string
if not api_key:
msg = (
"SerperDevWebSearch expects an API key. "
"Set the SERPERDEV_API_KEY environment variable (recommended) or pass it explicitly."
)
raise ValueError(msg)
self.api_key = api_key self.api_key = api_key
self.top_k = top_k self.top_k = top_k
self.allowed_domains = allowed_domains self.allowed_domains = allowed_domains
self.search_params = search_params or {} self.search_params = search_params or {}
# Ensure that the API key is resolved.
_ = self.api_key.resolve_value()
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
""" """
Serialize this component to a dictionary. Serialize this component to a dictionary.
""" """
return default_to_dict( return default_to_dict(
self, top_k=self.top_k, allowed_domains=self.allowed_domains, search_params=self.search_params self,
top_k=self.top_k,
allowed_domains=self.allowed_domains,
search_params=self.search_params,
api_key=self.api_key.to_dict(),
) )
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "SerperDevWebSearch":
"""
Deserialize this component from a dictionary.
"""
deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"])
return default_from_dict(cls, data)
@component.output_types(documents=List[Document], links=List[str]) @component.output_types(documents=List[Document], links=List[str])
def run(self, query: str): def run(self, query: str):
""" """
@ -76,10 +80,10 @@ class SerperDevWebSearch:
payload = json.dumps( payload = json.dumps(
{"q": query_prepend + query, "gl": "us", "hl": "en", "autocorrect": True, **self.search_params} {"q": query_prepend + query, "gl": "us", "hl": "en", "autocorrect": True, **self.search_params}
) )
headers = {"X-API-KEY": self.api_key, "Content-Type": "application/json"} headers = {"X-API-KEY": self.api_key.resolve_value(), "Content-Type": "application/json"}
try: try:
response = requests.post(SERPERDEV_BASE_URL, headers=headers, data=payload, timeout=30) response = requests.post(SERPERDEV_BASE_URL, headers=headers, data=payload, timeout=30) # type: ignore
response.raise_for_status() # Will raise an HTTPError for bad responses response.raise_for_status() # Will raise an HTTPError for bad responses
except requests.Timeout: except requests.Timeout:
raise TimeoutError(f"Request to {self.__class__.__name__} timed out.") raise TimeoutError(f"Request to {self.__class__.__name__} timed out.")
@ -2,4 +2,4 @@ from haystack.utils.expit import expit
from haystack.utils.requests_utils import request_with_retry from haystack.utils.requests_utils import request_with_retry
from haystack.utils.filters import document_matches_filter from haystack.utils.filters import document_matches_filter
from haystack.utils.device import ComponentDevice, DeviceType, Device, DeviceMap from haystack.utils.device import ComponentDevice, DeviceType, Device, DeviceMap
from haystack.utils.auth import Secret from haystack.utils.auth import Secret, deserialize_secrets_inplace
@ -1,6 +1,6 @@
from enum import Enum from enum import Enum
import os import os
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, Iterable, List, Optional, Union
from dataclasses import dataclass from dataclasses import dataclass
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
@ -193,3 +193,21 @@ class EnvVarSecret(Secret):
if out is None and self._strict: if out is None and self._strict:
raise ValueError(f"None of the following authentication environment variables are set: {self._env_vars}") raise ValueError(f"None of the following authentication environment variables are set: {self._env_vars}")
return out return out
def deserialize_secrets_inplace(data: Dict[str, Any], keys: Iterable[str], *, recursive: bool = False):
"""
Deserialize secrets in a dictionary inplace.
:param data:
The dictionary with the serialized data.
:param keys:
The keys of the secrets to deserialize.
:param recursive:
Whether to recursively deserialize nested dictionaries.
"""
for k, v in data.items():
if isinstance(v, dict) and recursive:
deserialize_secrets_inplace(v, keys, recursive=recursive)
elif k in keys and v is not None:
data[k] = Secret.from_dict(v)
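A short usage sketch of the helper above, round-tripping an env-var secret through its serialized form:

```python
from haystack.utils import Secret, deserialize_secrets_inplace

params = {"api_key": Secret.from_env_var("OPENAI_API_KEY").to_dict(), "model": "gpt-3.5-turbo"}
deserialize_secrets_inplace(params, keys=["api_key"])
assert params["api_key"] == Secret.from_env_var("OPENAI_API_KEY")
```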
@ -4,3 +4,37 @@ features:
Expose a `Secret` type to provide consistent API for any component that requires secrets for authentication. Expose a `Secret` type to provide consistent API for any component that requires secrets for authentication.
Currently supports string tokens and environment variables. Token-based secrets are automatically Currently supports string tokens and environment variables. Token-based secrets are automatically
prevented from being serialized to disk (to prevent accidental leakage of secrets). prevented from being serialized to disk (to prevent accidental leakage of secrets).
```python
@component
class MyComponent:
def __init__(self, api_key: Optional[Secret] = None, **kwargs):
self.api_key = api_key
self.backend = None
def warm_up(self):
# Call resolve_value to yield a single result. The semantics of the result is policy-dependent.
# Currently, all supported policies will return a single string token.
self.backend = SomeBackend(api_key=self.api_key.resolve_value() if self.api_key else None, ...)
def to_dict(self):
# Serialize the policy like any other (custom) data. If the policy is token-based, it will
# raise an error.
return default_to_dict(self, api_key=self.api_key.to_dict() if self.api_key else None, ...)
@classmethod
def from_dict(cls, data):
# Deserialize the policy data before passing it to the generic from_dict function.
api_key_data = data["init_parameters"]["api_key"]
api_key = Secret.from_dict(api_key_data) if api_key_data is not None else None
data["init_parameters"]["api_key"] = api_key
return default_from_dict(cls, data)
# No authentication.
component = MyComponent(api_key=None)
# Token based authentication
component = MyComponent(api_key=Secret.from_token("sk-randomAPIkeyasdsa32ekasd32e"))
component.to_dict() # Error! Can't serialize authentication tokens
# Environment variable based authentication
component = MyComponent(api_key=Secret.from_env_var("OPENAI_API_KEY"))
component.to_dict() # This is fine
```
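The serialization guard can also be exercised on the `Secret` directly; a small sketch with a placeholder token:

```python
from haystack.utils import Secret

secret = Secret.from_token("<your-api-key>")
secret.resolve_value()  # -> "<your-api-key>"
secret.to_dict()        # raises: token-based secrets refuse serialization
```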
@ -0,0 +1,8 @@
---
upgrade:
- |
Update secret handling for components using the `Secret` type. The following components are affected:
`RemoteWhisperTranscriber`, `AzureOCRDocumentConverter`, `AzureOpenAIDocumentEmbedder`, `AzureOpenAITextEmbedder`, `HuggingFaceTEIDocumentEmbedder`, `HuggingFaceTEITextEmbedder`, `OpenAIDocumentEmbedder`, `SentenceTransformersDocumentEmbedder`, `SentenceTransformersTextEmbedder`, `AzureOpenAIGenerator`, `AzureOpenAIChatGenerator`, `HuggingFaceLocalChatGenerator`, `HuggingFaceTGIChatGenerator`, `OpenAIChatGenerator`, `HuggingFaceLocalGenerator`, `HuggingFaceTGIGenerator`, `OpenAIGenerator`, `TransformersSimilarityRanker`, `SearchApiWebSearch`, `SerperDevWebSearch`
The default init parameters for `api_key`, `token`, and `azure_ad_token` have been adjusted to use environment variables wherever possible. The `azure_ad_token_provider` parameter has been removed from Azure-based components. Hugging Face-based components now require either a token or an environment variable when authentication is needed; the on-disk local token file is no longer supported.
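A minimal migration sketch for one affected component; the same pattern applies to the others listed above:

```python
from haystack.components.generators import OpenAIGenerator
from haystack.utils import Secret

# Before: OpenAIGenerator(api_key="<your-api-key>")

# Now, either rely on the OPENAI_API_KEY environment variable
# (the default; assumes the variable is set)...
generator = OpenAIGenerator()

# ...or pass a token secret explicitly (not serializable by design).
generator = OpenAIGenerator(api_key=Secret.from_token("<your-api-key>"))
```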
@ -1,29 +1,29 @@
import os import os
import pytest import pytest
from openai import OpenAIError
from haystack.components.audio.whisper_remote import RemoteWhisperTranscriber from haystack.components.audio.whisper_remote import RemoteWhisperTranscriber
from haystack.dataclasses import ByteStream from haystack.dataclasses import ByteStream
from haystack.utils import Secret
class TestRemoteWhisperTranscriber: class TestRemoteWhisperTranscriber:
def test_init_no_key(self, monkeypatch): def test_init_no_key(self, monkeypatch):
monkeypatch.delenv("OPENAI_API_KEY", raising=False) monkeypatch.delenv("OPENAI_API_KEY", raising=False)
with pytest.raises(OpenAIError): with pytest.raises(ValueError, match="None of the .* environment variables are set"):
RemoteWhisperTranscriber(api_key=None) RemoteWhisperTranscriber()
def test_init_key_env_var(self, monkeypatch): def test_init_key_env_var(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test_api_key") monkeypatch.setenv("OPENAI_API_KEY", "test_api_key")
-        t = RemoteWhisperTranscriber(api_key=None)
+        t = RemoteWhisperTranscriber()
         assert t.client.api_key == "test_api_key"

     def test_init_key_module_env_and_global_var(self, monkeypatch):
         monkeypatch.setenv("OPENAI_API_KEY", "test_api_key_2")
-        t = RemoteWhisperTranscriber(api_key=None)
+        t = RemoteWhisperTranscriber()
         assert t.client.api_key == "test_api_key_2"

     def test_init_default(self):
-        transcriber = RemoteWhisperTranscriber(api_key="test_api_key")
+        transcriber = RemoteWhisperTranscriber(api_key=Secret.from_token("test_api_key"))
         assert transcriber.client.api_key == "test_api_key"
         assert transcriber.model == "whisper-1"
         assert transcriber.organization is None
@@ -31,7 +31,7 @@ class TestRemoteWhisperTranscriber:
     def test_init_custom_parameters(self):
         transcriber = RemoteWhisperTranscriber(
-            api_key="test_api_key",
+            api_key=Secret.from_token("test_api_key"),
             model="whisper-1",
             organization="test-org",
             api_base_url="test_api_url",
@@ -42,6 +42,7 @@ class TestRemoteWhisperTranscriber:
         )
         assert transcriber.model == "whisper-1"
+        assert transcriber.api_key == Secret.from_token("test_api_key")
         assert transcriber.organization == "test-org"
         assert transcriber.api_base_url == "test_api_url"
         assert transcriber.whisper_params == {
@@ -51,12 +52,14 @@ class TestRemoteWhisperTranscriber:
             "temperature": "0.5",
         }

-    def test_to_dict_default_parameters(self):
-        transcriber = RemoteWhisperTranscriber(api_key="test_api_key")
+    def test_to_dict_default_parameters(self, monkeypatch):
+        monkeypatch.setenv("OPENAI_API_KEY", "test_api_key")
+        transcriber = RemoteWhisperTranscriber()
         data = transcriber.to_dict()
         assert data == {
             "type": "haystack.components.audio.whisper_remote.RemoteWhisperTranscriber",
             "init_parameters": {
+                "api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
                 "model": "whisper-1",
                 "api_base_url": None,
                 "organization": None,
@@ -64,9 +67,10 @@ class TestRemoteWhisperTranscriber:
             },
         }

-    def test_to_dict_with_custom_init_parameters(self):
+    def test_to_dict_with_custom_init_parameters(self, monkeypatch):
+        monkeypatch.setenv("OPENAI_API_KEY", "test_api_key")
         transcriber = RemoteWhisperTranscriber(
-            api_key="test_api_key",
+            api_key=Secret.from_env_var("ENV_VAR", strict=False),
             model="whisper-1",
             organization="test-org",
             api_base_url="test_api_url",
@@ -79,6 +83,7 @@ class TestRemoteWhisperTranscriber:
         assert data == {
             "type": "haystack.components.audio.whisper_remote.RemoteWhisperTranscriber",
             "init_parameters": {
+                "api_key": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
                 "model": "whisper-1",
                 "organization": "test-org",
                 "api_base_url": "test_api_url",
@@ -142,6 +147,7 @@ class TestRemoteWhisperTranscriber:
         data = {
             "type": "haystack.components.audio.whisper_remote.RemoteWhisperTranscriber",
             "init_parameters": {
+                "api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
                 "model": "whisper-1",
                 "api_base_url": "https://api.openai.com/v1",
                 "organization": None,
@@ -149,7 +155,7 @@ class TestRemoteWhisperTranscriber:
             },
         }
-        with pytest.raises(OpenAIError):
+        with pytest.raises(ValueError, match="None of the .* environment variables are set"):
             RemoteWhisperTranscriber.from_dict(data)

     @pytest.mark.skipif(
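The transcriber tests above pin down the new contract: the component defaults to Secret.from_env_var("OPENAI_API_KEY"), and to_dict() serializes only the reference to that variable, never the key itself. A minimal sketch of the round trip (the key value below is a stand-in):

import os
from haystack.components.audio.whisper_remote import RemoteWhisperTranscriber

os.environ["OPENAI_API_KEY"] = "test_api_key"  # stand-in value for the sketch
transcriber = RemoteWhisperTranscriber()       # picks up Secret.from_env_var("OPENAI_API_KEY")
data = transcriber.to_dict()
# data["init_parameters"]["api_key"] holds only the env-var reference:
# {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"}
restored = RemoteWhisperTranscriber.from_dict(data)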
View File
@@ -5,6 +5,7 @@ import pytest

 from haystack.components.converters.azure import AzureOCRDocumentConverter
 from haystack.dataclasses import ByteStream
+from haystack.utils import Secret


 class TestAzureOCRDocumentConverter:
@@ -13,15 +14,23 @@ class TestAzureOCRDocumentConverter:
         with pytest.raises(ValueError):
             AzureOCRDocumentConverter(endpoint="test_endpoint")

-    def test_to_dict(self):
-        component = AzureOCRDocumentConverter(endpoint="test_endpoint", api_key="test_credential_key")
+    @patch("haystack.utils.auth.EnvVarSecret.resolve_value")
+    def test_to_dict(self, mock_resolve_value):
+        mock_resolve_value.return_value = "test_api_key"
+        component = AzureOCRDocumentConverter(endpoint="test_endpoint")
         data = component.to_dict()
         assert data == {
             "type": "haystack.components.converters.azure.AzureOCRDocumentConverter",
-            "init_parameters": {"endpoint": "test_endpoint", "model_id": "prebuilt-read"},
+            "init_parameters": {
+                "api_key": {"env_vars": ["AZURE_AI_API_KEY"], "strict": True, "type": "env_var"},
+                "endpoint": "test_endpoint",
+                "model_id": "prebuilt-read",
+            },
         }

-    def test_run(self, test_files_path):
+    @patch("haystack.utils.auth.EnvVarSecret.resolve_value")
+    def test_run(self, mock_resolve_value, test_files_path):
+        mock_resolve_value.return_value = "test_api_key"
         with patch("haystack.components.converters.azure.DocumentAnalysisClient") as mock_azure_client:
             mock_result = Mock(pages=[Mock(lines=[Mock(content="mocked line 1"), Mock(content="mocked line 2")])])
             mock_result.to_dict.return_value = {
@@ -32,7 +41,7 @@ class TestAzureOCRDocumentConverter:
             }
             mock_azure_client.return_value.begin_analyze_document.return_value.result.return_value = mock_result

-            component = AzureOCRDocumentConverter(endpoint="test_endpoint", api_key="test_credential_key")
+            component = AzureOCRDocumentConverter(endpoint="test_endpoint")
             output = component.run(sources=[test_files_path / "pdf" / "sample_pdf_1.pdf"])
             document = output["documents"][0]
             assert document.content == "mocked line 1\nmocked line 2\n\f"
@@ -44,11 +53,12 @@ class TestAzureOCRDocumentConverter:
             "pages": [{"lines": [{"content": "mocked line 1"}, {"content": "mocked line 2"}]}],
         }

-    def test_run_with_meta(self, test_files_path):
+    @patch("haystack.utils.auth.EnvVarSecret.resolve_value")
+    def test_run_with_meta(self, mock_resolve_value, test_files_path):
+        mock_resolve_value.return_value = "test_api_key"
         bytestream = ByteStream(data=b"test", meta={"author": "test_author", "language": "en"})
         with patch("haystack.components.converters.azure.DocumentAnalysisClient"):
-            component = AzureOCRDocumentConverter(endpoint="test_endpoint", api_key="test_credential_key")
+            component = AzureOCRDocumentConverter(endpoint="test_endpoint")
             output = component.run(
                 sources=[bytestream, test_files_path / "pdf" / "sample_pdf_1.pdf"], meta={"language": "it"}
             )
@@ -63,7 +73,7 @@ class TestAzureOCRDocumentConverter:
     @pytest.mark.skipif(not os.environ.get("CORE_AZURE_CS_API_KEY", None), reason="Azure credentials not available")
     def test_run_with_pdf_file(self, test_files_path):
         component = AzureOCRDocumentConverter(
-            endpoint=os.environ["CORE_AZURE_CS_ENDPOINT"], api_key=os.environ["CORE_AZURE_CS_API_KEY"]
+            endpoint=os.environ["CORE_AZURE_CS_ENDPOINT"], api_key=Secret.from_env_var("CORE_AZURE_CS_API_KEY")
         )
         output = component.run(sources=[test_files_path / "pdf" / "sample_pdf_1.pdf"])
         documents = output["documents"]
@@ -77,7 +87,7 @@ class TestAzureOCRDocumentConverter:
     @pytest.mark.skipif(not os.environ.get("CORE_AZURE_CS_API_KEY", None), reason="Azure credentials not available")
     def test_with_image_file(self, test_files_path):
         component = AzureOCRDocumentConverter(
-            endpoint=os.environ["CORE_AZURE_CS_ENDPOINT"], api_key=os.environ["CORE_AZURE_CS_API_KEY"]
+            endpoint=os.environ["CORE_AZURE_CS_ENDPOINT"], api_key=Secret.from_env_var("CORE_AZURE_CS_API_KEY")
         )
         output = component.run(sources=[test_files_path / "images" / "haystack-logo.png"])
         documents = output["documents"]
@@ -90,7 +100,7 @@ class TestAzureOCRDocumentConverter:
     @pytest.mark.skipif(not os.environ.get("CORE_AZURE_CS_API_KEY", None), reason="Azure credentials not available")
     def test_run_with_docx_file(self, test_files_path):
         component = AzureOCRDocumentConverter(
-            endpoint=os.environ["CORE_AZURE_CS_ENDPOINT"], api_key=os.environ["CORE_AZURE_CS_API_KEY"]
+            endpoint=os.environ["CORE_AZURE_CS_ENDPOINT"], api_key=Secret.from_env_var("CORE_AZURE_CS_API_KEY")
         )
         output = component.run(sources=[test_files_path / "docx" / "sample_docx.docx"])
         documents = output["documents"]
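The converter tests above show the pattern for components whose secret is strict: instead of exporting AZURE_AI_API_KEY, they patch EnvVarSecret.resolve_value so construction succeeds with no real credential. A condensed sketch of that pattern:

from unittest.mock import patch

from haystack.components.converters.azure import AzureOCRDocumentConverter

@patch("haystack.utils.auth.EnvVarSecret.resolve_value")
def test_to_dict(mock_resolve_value):
    mock_resolve_value.return_value = "test_api_key"  # no real env var needed
    component = AzureOCRDocumentConverter(endpoint="test_endpoint")
    assert component.to_dict()["init_parameters"]["endpoint"] == "test_endpoint"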
View File
@@ -19,14 +19,15 @@ class TestAzureOpenAIDocumentEmbedder:
         assert embedder.meta_fields_to_embed == []
         assert embedder.embedding_separator == "\n"

-    def test_to_dict(self):
-        component = AzureOpenAIDocumentEmbedder(
-            api_key="fake-api-key", azure_endpoint="https://example-resource.azure.openai.com/"
-        )
+    def test_to_dict(self, monkeypatch):
+        monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key")
+        component = AzureOpenAIDocumentEmbedder(azure_endpoint="https://example-resource.azure.openai.com/")
         data = component.to_dict()
         assert data == {
             "type": "haystack.components.embedders.azure_document_embedder.AzureOpenAIDocumentEmbedder",
             "init_parameters": {
+                "api_key": {"env_vars": ["AZURE_OPENAI_API_KEY"], "strict": False, "type": "env_var"},
+                "azure_ad_token": {"env_vars": ["AZURE_OPENAI_AD_TOKEN"], "strict": False, "type": "env_var"},
                 "api_version": "2023-05-15",
                 "azure_deployment": "text-embedding-ada-002",
                 "azure_endpoint": "https://example-resource.azure.openai.com/",
View File
@@ -16,14 +16,15 @@ class TestAzureOpenAITextEmbedder:
         assert embedder.prefix == ""
         assert embedder.suffix == ""

-    def test_to_dict(self):
-        component = AzureOpenAITextEmbedder(
-            api_key="fake-api-key", azure_endpoint="https://example-resource.azure.openai.com/"
-        )
+    def test_to_dict(self, monkeypatch):
+        monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key")
+        component = AzureOpenAITextEmbedder(azure_endpoint="https://example-resource.azure.openai.com/")
         data = component.to_dict()
         assert data == {
             "type": "haystack.components.embedders.azure_text_embedder.AzureOpenAITextEmbedder",
             "init_parameters": {
+                "api_key": {"env_vars": ["AZURE_OPENAI_API_KEY"], "strict": False, "type": "env_var"},
+                "azure_ad_token": {"env_vars": ["AZURE_OPENAI_AD_TOKEN"], "strict": False, "type": "env_var"},
                 "azure_deployment": "text-embedding-ada-002",
                 "organization": None,
                 "azure_endpoint": "https://example-resource.azure.openai.com/",
View File
@@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch
 import numpy as np
 import pytest
 from huggingface_hub.utils import RepositoryNotFoundError
+from haystack.utils.auth import Secret

 from haystack.components.embedders.hugging_face_tei_document_embedder import HuggingFaceTEIDocumentEmbedder
 from haystack.dataclasses import Document
@@ -29,7 +30,7 @@ class TestHuggingFaceTEIDocumentEmbedder:
         assert embedder.model == "BAAI/bge-small-en-v1.5"
         assert embedder.url is None
-        assert embedder.token == "fake-api-token"
+        assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
         assert embedder.prefix == ""
         assert embedder.suffix == ""
         assert embedder.batch_size == 32
@@ -41,7 +42,7 @@ class TestHuggingFaceTEIDocumentEmbedder:
         embedder = HuggingFaceTEIDocumentEmbedder(
             model="sentence-transformers/all-mpnet-base-v2",
             url="https://some_embedding_model.com",
-            token="fake-api-token",
+            token=Secret.from_token("fake-api-token"),
             prefix="prefix",
             suffix="suffix",
             batch_size=64,
@@ -52,7 +53,7 @@ class TestHuggingFaceTEIDocumentEmbedder:
         assert embedder.model == "sentence-transformers/all-mpnet-base-v2"
         assert embedder.url == "https://some_embedding_model.com"
-        assert embedder.token == "fake-api-token"
+        assert embedder.token == Secret.from_token("fake-api-token")
         assert embedder.prefix == "prefix"
         assert embedder.suffix == "suffix"
         assert embedder.batch_size == 64
@@ -78,6 +79,7 @@ class TestHuggingFaceTEIDocumentEmbedder:
             "type": "haystack.components.embedders.hugging_face_tei_document_embedder.HuggingFaceTEIDocumentEmbedder",
             "init_parameters": {
                 "model": "BAAI/bge-small-en-v1.5",
+                "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
                 "url": None,
                 "prefix": "",
                 "suffix": "",
@@ -92,7 +94,7 @@ class TestHuggingFaceTEIDocumentEmbedder:
         component = HuggingFaceTEIDocumentEmbedder(
             model="sentence-transformers/all-mpnet-base-v2",
             url="https://some_embedding_model.com",
-            token="fake-api-token",
+            token=Secret.from_env_var("ENV_VAR", strict=False),
             prefix="prefix",
             suffix="suffix",
             batch_size=64,
@@ -106,6 +108,7 @@ class TestHuggingFaceTEIDocumentEmbedder:
         assert data == {
             "type": "haystack.components.embedders.hugging_face_tei_document_embedder.HuggingFaceTEIDocumentEmbedder",
             "init_parameters": {
+                "token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
                 "model": "sentence-transformers/all-mpnet-base-v2",
                 "url": "https://some_embedding_model.com",
                 "prefix": "prefix",
@@ -125,7 +128,7 @@ class TestHuggingFaceTEIDocumentEmbedder:
         embedder = HuggingFaceTEIDocumentEmbedder(
             model="sentence-transformers/all-mpnet-base-v2",
             url="https://some_embedding_model.com",
-            token="fake-api-token",
+            token=Secret.from_token("fake-api-token"),
             meta_fields_to_embed=["meta_field"],
             embedding_separator=" | ",
         )
@@ -146,7 +149,7 @@ class TestHuggingFaceTEIDocumentEmbedder:
         embedder = HuggingFaceTEIDocumentEmbedder(
             model="sentence-transformers/all-mpnet-base-v2",
             url="https://some_embedding_model.com",
-            token="fake-api-token",
+            token=Secret.from_token("fake-api-token"),
             prefix="my_prefix ",
             suffix=" my_suffix",
         )
@@ -168,7 +171,9 @@ class TestHuggingFaceTEIDocumentEmbedder:
         mock_embedding_patch.side_effect = mock_embedding_generation

         embedder = HuggingFaceTEIDocumentEmbedder(
-            model="BAAI/bge-small-en-v1.5", url="https://some_embedding_model.com", token="fake-api-token"
+            model="BAAI/bge-small-en-v1.5",
+            url="https://some_embedding_model.com",
+            token=Secret.from_token("fake-api-token"),
         )

         embeddings = embedder._embed_batch(texts_to_embed=texts, batch_size=2)
@@ -192,7 +197,7 @@ class TestHuggingFaceTEIDocumentEmbedder:
         embedder = HuggingFaceTEIDocumentEmbedder(
             model="BAAI/bge-small-en-v1.5",
-            token="fake-api-token",
+            token=Secret.from_token("fake-api-token"),
             prefix="prefix ",
             suffix=" suffix",
             meta_fields_to_embed=["topic"],
@@ -228,7 +233,7 @@ class TestHuggingFaceTEIDocumentEmbedder:
         embedder = HuggingFaceTEIDocumentEmbedder(
             model="BAAI/bge-small-en-v1.5",
-            token="fake-api-token",
+            token=Secret.from_token("fake-api-token"),
             prefix="prefix ",
             suffix=" suffix",
             meta_fields_to_embed=["topic"],
@@ -252,7 +257,9 @@ class TestHuggingFaceTEIDocumentEmbedder:
     def test_run_wrong_input_format(self, mock_check_valid_model):
         embedder = HuggingFaceTEIDocumentEmbedder(
-            model="BAAI/bge-small-en-v1.5", url="https://some_embedding_model.com", token="fake-api-token"
+            model="BAAI/bge-small-en-v1.5",
+            url="https://some_embedding_model.com",
+            token=Secret.from_token("fake-api-token"),
         )

         # wrong formats
@@ -267,7 +274,9 @@ class TestHuggingFaceTEIDocumentEmbedder:
     def test_run_on_empty_list(self, mock_check_valid_model):
         embedder = HuggingFaceTEIDocumentEmbedder(
-            model="BAAI/bge-small-en-v1.5", url="https://some_embedding_model.com", token="fake-api-token"
+            model="BAAI/bge-small-en-v1.5",
+            url="https://some_embedding_model.com",
+            token=Secret.from_token("fake-api-token"),
         )

         empty_list_input = []
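For both TEI embedders there are now two ways to pass a Hub token, and the tests exercise each: Secret.from_token wraps a literal value for runtime use, while omitting the argument falls back non-strictly to the HF_API_TOKEN environment variable. A brief sketch (as in the tests above, model validation may need to be mocked offline):

from haystack.utils.auth import Secret
from haystack.components.embedders.hugging_face_tei_document_embedder import HuggingFaceTEIDocumentEmbedder

# Literal token: usable at runtime, but deliberately not serializable.
embedder = HuggingFaceTEIDocumentEmbedder(
    model="BAAI/bge-small-en-v1.5",
    token=Secret.from_token("fake-api-token"),
)

# Default: a lenient env-var secret, safe to serialize.
embedder = HuggingFaceTEIDocumentEmbedder(model="BAAI/bge-small-en-v1.5")
assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False)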
View File
@@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch
 import numpy as np
 import pytest
 from huggingface_hub.utils import RepositoryNotFoundError
+from haystack.utils.auth import Secret

 from haystack.components.embedders.hugging_face_tei_text_embedder import HuggingFaceTEITextEmbedder
@@ -27,7 +28,7 @@ class TestHuggingFaceTEITextEmbedder:
         assert embedder.model == "BAAI/bge-small-en-v1.5"
         assert embedder.url is None
-        assert embedder.token == "fake-api-token"
+        assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
         assert embedder.prefix == ""
         assert embedder.suffix == ""
@@ -35,14 +36,14 @@ class TestHuggingFaceTEITextEmbedder:
         embedder = HuggingFaceTEITextEmbedder(
             model="sentence-transformers/all-mpnet-base-v2",
             url="https://some_embedding_model.com",
-            token="fake-api-token",
+            token=Secret.from_token("fake-api-token"),
             prefix="prefix",
             suffix="suffix",
         )

         assert embedder.model == "sentence-transformers/all-mpnet-base-v2"
         assert embedder.url == "https://some_embedding_model.com"
-        assert embedder.token == "fake-api-token"
+        assert embedder.token == Secret.from_token("fake-api-token")
         assert embedder.prefix == "prefix"
         assert embedder.suffix == "suffix"
@@ -62,14 +63,20 @@ class TestHuggingFaceTEITextEmbedder:
         assert data == {
             "type": "haystack.components.embedders.hugging_face_tei_text_embedder.HuggingFaceTEITextEmbedder",
-            "init_parameters": {"model": "BAAI/bge-small-en-v1.5", "url": None, "prefix": "", "suffix": ""},
+            "init_parameters": {
+                "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
+                "model": "BAAI/bge-small-en-v1.5",
+                "url": None,
+                "prefix": "",
+                "suffix": "",
+            },
         }

     def test_to_dict_with_custom_init_parameters(self, mock_check_valid_model):
         component = HuggingFaceTEITextEmbedder(
             model="sentence-transformers/all-mpnet-base-v2",
             url="https://some_embedding_model.com",
-            token="fake-api-token",
+            token=Secret.from_env_var("ENV_VAR", strict=False),
             prefix="prefix",
             suffix="suffix",
         )
@@ -79,6 +86,7 @@ class TestHuggingFaceTEITextEmbedder:
         assert data == {
             "type": "haystack.components.embedders.hugging_face_tei_text_embedder.HuggingFaceTEITextEmbedder",
             "init_parameters": {
+                "token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
                 "model": "sentence-transformers/all-mpnet-base-v2",
                 "url": "https://some_embedding_model.com",
                 "prefix": "prefix",
@@ -91,7 +99,10 @@ class TestHuggingFaceTEITextEmbedder:
         mock_embedding_patch.side_effect = mock_embedding_generation

         embedder = HuggingFaceTEITextEmbedder(
-            model="BAAI/bge-small-en-v1.5", token="fake-api-token", prefix="prefix ", suffix=" suffix"
+            model="BAAI/bge-small-en-v1.5",
+            token=Secret.from_token("fake-api-token"),
+            prefix="prefix ",
+            suffix=" suffix",
         )

         result = embedder.run(text="The food was delicious")
@@ -103,7 +114,9 @@ class TestHuggingFaceTEITextEmbedder:
     def test_run_wrong_input_format(self, mock_check_valid_model):
         embedder = HuggingFaceTEITextEmbedder(
-            model="BAAI/bge-small-en-v1.5", url="https://some_embedding_model.com", token="fake-api-token"
+            model="BAAI/bge-small-en-v1.5",
+            url="https://some_embedding_model.com",
+            token=Secret.from_token("fake-api-token"),
         )

         list_integers_input = [1, 2, 3]
View File
@@ -1,9 +1,9 @@
 import os
 from typing import List

+from haystack.utils.auth import Secret
 import numpy as np
 import pytest
-from openai import OpenAIError

 from haystack import Document
 from haystack.components.embedders.openai_document_embedder import OpenAIDocumentEmbedder
@@ -37,7 +37,7 @@ class TestOpenAIDocumentEmbedder:
     def test_init_with_parameters(self):
         embedder = OpenAIDocumentEmbedder(
-            api_key="fake-api-key",
+            api_key=Secret.from_token("fake-api-key"),
             model="model",
             organization="my-org",
             prefix="prefix",
@@ -58,15 +58,17 @@ class TestOpenAIDocumentEmbedder:
     def test_init_fail_wo_api_key(self, monkeypatch):
         monkeypatch.delenv("OPENAI_API_KEY", raising=False)
-        with pytest.raises(OpenAIError):
+        with pytest.raises(ValueError, match="None of the .* environment variables are set"):
             OpenAIDocumentEmbedder()

-    def test_to_dict(self):
-        component = OpenAIDocumentEmbedder(api_key="fake-api-key")
+    def test_to_dict(self, monkeypatch):
+        monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key")
+        component = OpenAIDocumentEmbedder()
         data = component.to_dict()
         assert data == {
             "type": "haystack.components.embedders.openai_document_embedder.OpenAIDocumentEmbedder",
             "init_parameters": {
+                "api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
                 "api_base_url": None,
                 "model": "text-embedding-ada-002",
                 "organization": None,
@@ -79,9 +81,10 @@ class TestOpenAIDocumentEmbedder:
             },
         }

-    def test_to_dict_with_custom_init_parameters(self):
+    def test_to_dict_with_custom_init_parameters(self, monkeypatch):
+        monkeypatch.setenv("ENV_VAR", "fake-api-key")
         component = OpenAIDocumentEmbedder(
-            api_key="fake-api-key",
+            api_key=Secret.from_env_var("ENV_VAR", strict=False),
             model="model",
             organization="my-org",
             prefix="prefix",
@@ -95,6 +98,7 @@ class TestOpenAIDocumentEmbedder:
         assert data == {
             "type": "haystack.components.embedders.openai_document_embedder.OpenAIDocumentEmbedder",
             "init_parameters": {
+                "api_key": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
                 "api_base_url": None,
                 "model": "model",
                 "organization": "my-org",
@@ -113,7 +117,7 @@ class TestOpenAIDocumentEmbedder:
         ]

         embedder = OpenAIDocumentEmbedder(
-            api_key="fake-api-key", meta_fields_to_embed=["meta_field"], embedding_separator=" | "
+            api_key=Secret.from_token("fake-api-key"), meta_fields_to_embed=["meta_field"], embedding_separator=" | "
         )

         prepared_texts = embedder._prepare_texts_to_embed(documents)
@@ -130,7 +134,9 @@ class TestOpenAIDocumentEmbedder:
     def test_prepare_texts_to_embed_w_suffix(self):
         documents = [Document(content=f"document number {i}") for i in range(5)]

-        embedder = OpenAIDocumentEmbedder(api_key="fake-api-key", prefix="my_prefix ", suffix=" my_suffix")
+        embedder = OpenAIDocumentEmbedder(
+            api_key=Secret.from_token("fake-api-key"), prefix="my_prefix ", suffix=" my_suffix"
+        )

         prepared_texts = embedder._prepare_texts_to_embed(documents)
@@ -143,7 +149,7 @@ class TestOpenAIDocumentEmbedder:
         ]

     def test_run_wrong_input_format(self):
-        embedder = OpenAIDocumentEmbedder(api_key="fake-api-key")
+        embedder = OpenAIDocumentEmbedder(api_key=Secret.from_token("fake-api-key"))

         # wrong formats
         string_input = "text"
@@ -156,7 +162,7 @@ class TestOpenAIDocumentEmbedder:
             embedder.run(documents=list_integers_input)

     def test_run_on_empty_list(self):
-        embedder = OpenAIDocumentEmbedder(api_key="fake-api-key")
+        embedder = OpenAIDocumentEmbedder(api_key=Secret.from_token("fake-api-key"))

         empty_list_input = []
         result = embedder.run(documents=empty_list_input)
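The strictness flag is what separates the failure modes seen here: a strict secret (the default for Secret.from_env_var) raises at resolution time when its variable is unset, which is exactly what test_init_fail_wo_api_key asserts, while a non-strict one resolves quietly. A small sketch of the distinction (OPTIONAL_KEY is a hypothetical variable name):

from haystack.utils.auth import Secret

strict_key = Secret.from_env_var("OPENAI_API_KEY")               # strict=True is the default
lenient_key = Secret.from_env_var("OPTIONAL_KEY", strict=False)  # hypothetical variable

# With OPENAI_API_KEY unset, resolving strict_key raises a ValueError matching
# "None of the .* environment variables are set"; the lenient secret does not.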
View File
@@ -1,7 +1,7 @@
 import os

+from haystack.utils.auth import Secret
 import pytest
-from openai import OpenAIError

 from haystack.components.embedders.openai_text_embedder import OpenAITextEmbedder
@@ -19,7 +19,11 @@ class TestOpenAITextEmbedder:
     def test_init_with_parameters(self):
         embedder = OpenAITextEmbedder(
-            api_key="fake-api-key", model="model", organization="fake-organization", prefix="prefix", suffix="suffix"
+            api_key=Secret.from_token("fake-api-key"),
+            model="model",
+            organization="fake-organization",
+            prefix="prefix",
+            suffix="suffix",
         )
         assert embedder.client.api_key == "fake-api-key"
         assert embedder.model == "model"
@@ -29,25 +33,38 @@ class TestOpenAITextEmbedder:
     def test_init_fail_wo_api_key(self, monkeypatch):
         monkeypatch.delenv("OPENAI_API_KEY", raising=False)
-        with pytest.raises(OpenAIError):
+        with pytest.raises(ValueError, match="None of the .* environment variables are set"):
             OpenAITextEmbedder()

-    def test_to_dict(self):
-        component = OpenAITextEmbedder(api_key="fake-api-key")
+    def test_to_dict(self, monkeypatch):
+        monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key")
+        component = OpenAITextEmbedder()
         data = component.to_dict()
         assert data == {
             "type": "haystack.components.embedders.openai_text_embedder.OpenAITextEmbedder",
-            "init_parameters": {"model": "text-embedding-ada-002", "organization": None, "prefix": "", "suffix": ""},
+            "init_parameters": {
+                "api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
+                "model": "text-embedding-ada-002",
+                "organization": None,
+                "prefix": "",
+                "suffix": "",
+            },
         }

-    def test_to_dict_with_custom_init_parameters(self):
+    def test_to_dict_with_custom_init_parameters(self, monkeypatch):
+        monkeypatch.setenv("ENV_VAR", "fake-api-key")
         component = OpenAITextEmbedder(
-            api_key="fake-api-key", model="model", organization="fake-organization", prefix="prefix", suffix="suffix"
+            api_key=Secret.from_env_var("ENV_VAR", strict=False),
+            model="model",
+            organization="fake-organization",
+            prefix="prefix",
+            suffix="suffix",
         )
         data = component.to_dict()
         assert data == {
             "type": "haystack.components.embedders.openai_text_embedder.OpenAITextEmbedder",
             "init_parameters": {
+                "api_key": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
                 "model": "model",
                 "organization": "fake-organization",
                 "prefix": "prefix",
@@ -56,7 +73,7 @@ class TestOpenAITextEmbedder:
         }

     def test_run_wrong_input_format(self):
-        embedder = OpenAITextEmbedder(api_key="fake-api-key")
+        embedder = OpenAITextEmbedder(api_key=Secret.from_token("fake-api-key"))

         list_integers_input = [1, 2, 3]
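As test_init_with_parameters shows, the wrapped OpenAI client never sees the Secret object itself, only its resolved value. A one-line sketch:

from haystack.utils.auth import Secret
from haystack.components.embedders.openai_text_embedder import OpenAITextEmbedder

embedder = OpenAITextEmbedder(api_key=Secret.from_token("fake-api-key"))
assert embedder.client.api_key == "fake-api-key"  # resolve_value() happened inside __init__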
View File
@@ -1,6 +1,7 @@
 from unittest.mock import patch, MagicMock
 import pytest
 import numpy as np
+from haystack.utils.auth import Secret

 from haystack import Document
 from haystack.components.embedders.sentence_transformers_document_embedder import SentenceTransformersDocumentEmbedder
@@ -11,7 +12,7 @@ class TestSentenceTransformersDocumentEmbedder:
         embedder = SentenceTransformersDocumentEmbedder(model="model")
         assert embedder.model == "model"
         assert embedder.device == "cpu"
-        assert embedder.token is None
+        assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
         assert embedder.prefix == ""
         assert embedder.suffix == ""
         assert embedder.batch_size == 32
@@ -24,7 +25,7 @@ class TestSentenceTransformersDocumentEmbedder:
         embedder = SentenceTransformersDocumentEmbedder(
             model="model",
             device="cuda",
-            token=True,
+            token=Secret.from_token("fake-api-token"),
             prefix="prefix",
             suffix="suffix",
             batch_size=64,
@@ -35,7 +36,7 @@ class TestSentenceTransformersDocumentEmbedder:
         )
         assert embedder.model == "model"
         assert embedder.device == "cuda"
-        assert embedder.token is True
+        assert embedder.token == Secret.from_token("fake-api-token")
         assert embedder.prefix == "prefix"
         assert embedder.suffix == "suffix"
         assert embedder.batch_size == 64
@@ -52,7 +53,7 @@ class TestSentenceTransformersDocumentEmbedder:
             "init_parameters": {
                 "model": "model",
                 "device": "cpu",
-                "token": None,
+                "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
                 "prefix": "",
                 "suffix": "",
                 "batch_size": 32,
@@ -67,7 +68,7 @@ class TestSentenceTransformersDocumentEmbedder:
         component = SentenceTransformersDocumentEmbedder(
             model="model",
             device="cuda",
-            token="the-token",
+            token=Secret.from_env_var("ENV_VAR", strict=False),
             prefix="prefix",
             suffix="suffix",
             batch_size=64,
@@ -83,7 +84,7 @@ class TestSentenceTransformersDocumentEmbedder:
             "init_parameters": {
                 "model": "model",
                 "device": "cuda",
-                "token": None,  # the token is not serialized
+                "token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
                 "prefix": "prefix",
                 "suffix": "suffix",
                 "batch_size": 64,
@@ -98,10 +99,10 @@ class TestSentenceTransformersDocumentEmbedder:
         "haystack.components.embedders.sentence_transformers_document_embedder._SentenceTransformersEmbeddingBackendFactory"
     )
     def test_warmup(self, mocked_factory):
-        embedder = SentenceTransformersDocumentEmbedder(model="model")
+        embedder = SentenceTransformersDocumentEmbedder(model="model", token=None)
         mocked_factory.get_embedding_backend.assert_not_called()
         embedder.warm_up()
-        mocked_factory.get_embedding_backend.assert_called_once_with(model="model", device="cpu", use_auth_token=None)
+        mocked_factory.get_embedding_backend.assert_called_once_with(model="model", device="cpu", auth_token=None)

     @patch(
         "haystack.components.embedders.sentence_transformers_document_embedder._SentenceTransformersEmbeddingBackendFactory"
View File
@@ -3,6 +3,7 @@ import pytest
 from haystack.components.embedders.backends.sentence_transformers_backend import (
     _SentenceTransformersEmbeddingBackendFactory,
 )
+from haystack.utils.auth import Secret


 @patch("haystack.components.embedders.backends.sentence_transformers_backend.SentenceTransformer")
@@ -22,10 +23,10 @@ def test_factory_behavior(mock_sentence_transformer):
 @patch("haystack.components.embedders.backends.sentence_transformers_backend.SentenceTransformer")
 def test_model_initialization(mock_sentence_transformer):
     _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(
-        model="model", device="cpu", use_auth_token="my_token"
+        model="model", device="cpu", auth_token=Secret.from_token("fake-api-token")
     )
     mock_sentence_transformer.assert_called_once_with(
-        model_name_or_path="model", device="cpu", use_auth_token="my_token"
+        model_name_or_path="model", device="cpu", use_auth_token="fake-api-token"
     )
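The renamed keyword makes the hand-off explicit: the component passes the backend a Secret (auth_token), and the backend resolves it only when it finally builds the SentenceTransformer, which still takes a plain string (use_auth_token). A sketch of the relationship the assertion above pins down, not the literal implementation:

from typing import Optional
from sentence_transformers import SentenceTransformer
from haystack.utils.auth import Secret

def build_model(auth_token: Optional[Secret]) -> SentenceTransformer:
    # The Secret is resolved only here, at the last moment before the raw
    # string is handed to SentenceTransformer's use_auth_token parameter.
    return SentenceTransformer(
        model_name_or_path="model",  # placeholder model name
        device="cpu",
        use_auth_token=auth_token.resolve_value() if auth_token else None,
    )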
View File
@@ -1,5 +1,6 @@
 from unittest.mock import patch, MagicMock
 import pytest
+from haystack.utils.auth import Secret

 import numpy as np
@@ -11,7 +12,7 @@ class TestSentenceTransformersTextEmbedder:
         embedder = SentenceTransformersTextEmbedder(model="model")
         assert embedder.model == "model"
         assert embedder.device == "cpu"
-        assert embedder.token is None
+        assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
         assert embedder.prefix == ""
         assert embedder.suffix == ""
         assert embedder.batch_size == 32
@@ -22,7 +23,7 @@ class TestSentenceTransformersTextEmbedder:
         embedder = SentenceTransformersTextEmbedder(
             model="model",
             device="cuda",
-            token=True,
+            token=Secret.from_token("fake-api-token"),
             prefix="prefix",
             suffix="suffix",
             batch_size=64,
@@ -31,7 +32,7 @@ class TestSentenceTransformersTextEmbedder:
         )
         assert embedder.model == "model"
         assert embedder.device == "cuda"
-        assert embedder.token is True
+        assert embedder.token == Secret.from_token("fake-api-token")
         assert embedder.prefix == "prefix"
         assert embedder.suffix == "suffix"
         assert embedder.batch_size == 64
@@ -44,9 +45,9 @@ class TestSentenceTransformersTextEmbedder:
         assert data == {
             "type": "haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder",
             "init_parameters": {
+                "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
                 "model": "model",
                 "device": "cpu",
-                "token": None,
                 "prefix": "",
                 "suffix": "",
                 "batch_size": 32,
@@ -59,7 +60,7 @@ class TestSentenceTransformersTextEmbedder:
         component = SentenceTransformersTextEmbedder(
             model="model",
             device="cuda",
-            token=True,
+            token=Secret.from_env_var("ENV_VAR", strict=False),
             prefix="prefix",
             suffix="suffix",
             batch_size=64,
@@ -70,9 +71,9 @@ class TestSentenceTransformersTextEmbedder:
         assert data == {
             "type": "haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder",
             "init_parameters": {
+                "token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
                 "model": "model",
                 "device": "cuda",
-                "token": True,
                 "prefix": "prefix",
                 "suffix": "suffix",
                 "batch_size": 64,
@@ -82,30 +83,18 @@ class TestSentenceTransformersTextEmbedder:
         }

     def test_to_dict_not_serialize_token(self):
-        component = SentenceTransformersTextEmbedder(model="model", token="awesome-token")
-        data = component.to_dict()
-        assert data == {
-            "type": "haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder",
-            "init_parameters": {
-                "model": "model",
-                "device": "cpu",
-                "token": None,
-                "prefix": "",
-                "suffix": "",
-                "batch_size": 32,
-                "progress_bar": True,
-                "normalize_embeddings": False,
-            },
-        }
+        component = SentenceTransformersTextEmbedder(model="model", token=Secret.from_token("fake-api-token"))
+        with pytest.raises(ValueError, match="Cannot serialize token-based secret"):
+            component.to_dict()

     @patch(
         "haystack.components.embedders.sentence_transformers_text_embedder._SentenceTransformersEmbeddingBackendFactory"
     )
     def test_warmup(self, mocked_factory):
-        embedder = SentenceTransformersTextEmbedder(model="model")
+        embedder = SentenceTransformersTextEmbedder(model="model", token=None)
         mocked_factory.get_embedding_backend.assert_not_called()
         embedder.warm_up()
-        mocked_factory.get_embedding_backend.assert_called_once_with(model="model", device="cpu", use_auth_token=None)
+        mocked_factory.get_embedding_backend.assert_called_once_with(model="model", device="cpu", auth_token=None)

     @patch(
         "haystack.components.embedders.sentence_transformers_text_embedder._SentenceTransformersEmbeddingBackendFactory"
View File
@@ -6,11 +6,13 @@ from openai import OpenAIError
 from haystack.components.generators.chat import AzureOpenAIChatGenerator
 from haystack.components.generators.utils import print_streaming_chunk
 from haystack.dataclasses import ChatMessage
+from haystack.utils.auth import Secret


 class TestOpenAIChatGenerator:
-    def test_init_default(self):
-        component = AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint", api_key="test-api-key")
+    def test_init_default(self, monkeypatch):
+        monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key")
+        component = AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint")
         assert component.client.api_key == "test-api-key"
         assert component.azure_deployment == "gpt-35-turbo"
         assert component.streaming_callback is None
@@ -18,13 +20,14 @@ class TestOpenAIChatGenerator:
     def test_init_fail_wo_api_key(self, monkeypatch):
         monkeypatch.delenv("AZURE_OPENAI_API_KEY", raising=False)
+        monkeypatch.delenv("AZURE_OPENAI_AD_TOKEN", raising=False)
         with pytest.raises(OpenAIError):
             AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint")

     def test_init_with_parameters(self):
         component = AzureOpenAIChatGenerator(
+            api_key=Secret.from_token("test-api-key"),
             azure_endpoint="some-non-existing-endpoint",
-            api_key="test-api-key",
             streaming_callback=print_streaming_chunk,
             generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
         )
@@ -33,12 +36,15 @@ class TestOpenAIChatGenerator:
         assert component.streaming_callback is print_streaming_chunk
         assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}

-    def test_to_dict_default(self):
-        component = AzureOpenAIChatGenerator(api_key="test-api-key", azure_endpoint="some-non-existing-endpoint")
+    def test_to_dict_default(self, monkeypatch):
+        monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key")
+        component = AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint")
         data = component.to_dict()
         assert data == {
             "type": "haystack.components.generators.chat.azure.AzureOpenAIChatGenerator",
             "init_parameters": {
+                "api_key": {"env_vars": ["AZURE_OPENAI_API_KEY"], "strict": False, "type": "env_var"},
+                "azure_ad_token": {"env_vars": ["AZURE_OPENAI_AD_TOKEN"], "strict": False, "type": "env_var"},
                 "api_version": "2023-05-15",
                 "azure_endpoint": "some-non-existing-endpoint",
                 "azure_deployment": "gpt-35-turbo",
@@ -48,9 +54,11 @@ class TestOpenAIChatGenerator:
             },
         }

-    def test_to_dict_with_parameters(self):
+    def test_to_dict_with_parameters(self, monkeypatch):
+        monkeypatch.setenv("ENV_VAR", "test-api-key")
         component = AzureOpenAIChatGenerator(
-            api_key="test-api-key",
+            api_key=Secret.from_env_var("ENV_VAR", strict=False),
+            azure_ad_token=Secret.from_env_var("ENV_VAR1", strict=False),
             azure_endpoint="some-non-existing-endpoint",
             generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
         )
@@ -58,6 +66,8 @@ class TestOpenAIChatGenerator:
         assert data == {
             "type": "haystack.components.generators.chat.azure.AzureOpenAIChatGenerator",
             "init_parameters": {
+                "api_key": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
+                "azure_ad_token": {"env_vars": ["ENV_VAR1"], "strict": False, "type": "env_var"},
                 "api_version": "2023-05-15",
                 "azure_endpoint": "some-non-existing-endpoint",
                 "azure_deployment": "gpt-35-turbo",
View File
@@ -1,4 +1,5 @@
 from unittest.mock import patch, Mock
+from haystack.utils.auth import Secret

 import pytest
 from transformers import PreTrainedTokenizer
@@ -56,7 +57,7 @@ class TestHuggingFaceLocalChatGenerator:
         generator = HuggingFaceLocalChatGenerator(
             model="mistralai/Mistral-7B-Instruct-v0.2",
             task="text2text-generation",
-            token="test-token",
+            token=Secret.from_token("test-token"),
             device=ComponentDevice.from_str("cpu"),
         )

@@ -72,6 +73,7 @@ class TestHuggingFaceLocalChatGenerator:
             model="mistralai/Mistral-7B-Instruct-v0.2",
             task="text2text-generation",
             device=ComponentDevice.from_str("cpu"),
+            token=None,
         )

         assert generator.huggingface_pipeline_kwargs == {
@@ -82,7 +84,9 @@ class TestHuggingFaceLocalChatGenerator:
         }

     def test_init_task_parameter(self):
-        generator = HuggingFaceLocalChatGenerator(task="text2text-generation", device=ComponentDevice.from_str("cpu"))
+        generator = HuggingFaceLocalChatGenerator(
+            task="text2text-generation", device=ComponentDevice.from_str("cpu"), token=None
+        )

         assert generator.huggingface_pipeline_kwargs == {
             "model": "HuggingFaceH4/zephyr-7b-beta",
@@ -93,7 +97,9 @@ class TestHuggingFaceLocalChatGenerator:
     def test_init_task_in_huggingface_pipeline_kwargs(self):
         generator = HuggingFaceLocalChatGenerator(
-            huggingface_pipeline_kwargs={"task": "text2text-generation"}, device=ComponentDevice.from_str("cpu")
+            huggingface_pipeline_kwargs={"task": "text2text-generation"},
+            device=ComponentDevice.from_str("cpu"),
+            token=None,
         )

         assert generator.huggingface_pipeline_kwargs == {
@@ -105,7 +111,7 @@ class TestHuggingFaceLocalChatGenerator:
     def test_init_task_inferred_from_model_name(self, model_info_mock):
         generator = HuggingFaceLocalChatGenerator(
-            model="mistralai/Mistral-7B-Instruct-v0.2", device=ComponentDevice.from_str("cpu")
+            model="mistralai/Mistral-7B-Instruct-v0.2", device=ComponentDevice.from_str("cpu"), token=None
         )

         assert generator.huggingface_pipeline_kwargs == {
@@ -122,7 +128,7 @@ class TestHuggingFaceLocalChatGenerator:
     def test_to_dict(self, model_info_mock):
         generator = HuggingFaceLocalChatGenerator(
             model="NousResearch/Llama-2-7b-chat-hf",
-            token="token",
+            token=Secret.from_env_var("ENV_VAR", strict=False),
             generation_kwargs={"n": 5},
             stop_words=["stop", "words"],
             streaming_callback=lambda x: x,
@@ -133,6 +139,7 @@ class TestHuggingFaceLocalChatGenerator:
         init_params = result["init_parameters"]

         # Assert that the init_params dictionary contains the expected keys and values
+        assert init_params["token"] == {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}
         assert init_params["huggingface_pipeline_kwargs"]["model"] == "NousResearch/Llama-2-7b-chat-hf"
         assert "token" not in init_params["huggingface_pipeline_kwargs"]
         assert init_params["generation_kwargs"] == {"max_new_tokens": 512, "n": 5, "stop_sequences": ["stop", "words"]}
View File
@@ -1,4 +1,5 @@
 from unittest.mock import patch, MagicMock, Mock
+from haystack.utils.auth import Secret

 import pytest
 from huggingface_hub.inference._text_generation import TextGenerationStreamResponse, Token, StreamDetails, FinishReason
@@ -59,7 +60,7 @@ class TestHuggingFaceTGIChatGenerator:
         # Initialize the HuggingFaceTGIChatGenerator object with valid parameters
         generator = HuggingFaceTGIChatGenerator(
             model="NousResearch/Llama-2-7b-chat-hf",
-            token="token",
+            token=Secret.from_env_var("ENV_VAR", strict=False),
             generation_kwargs={"n": 5},
             stop_words=["stop", "words"],
             streaming_callback=lambda x: x,
@@ -71,7 +72,7 @@ class TestHuggingFaceTGIChatGenerator:
         # Assert that the init_params dictionary contains the expected keys and values
         assert init_params["model"] == "NousResearch/Llama-2-7b-chat-hf"
-        assert init_params["token"] is None
+        assert init_params["token"] == {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}
         assert init_params["generation_kwargs"] == {"n": 5, "stop_sequences": ["stop", "words"]}

     def test_from_dict(self, mock_check_valid_model):
View File
@@ -2,6 +2,7 @@ import os

 import pytest
 from openai import OpenAIError
+from haystack.utils.auth import Secret

 from haystack.components.generators.chat import OpenAIChatGenerator
 from haystack.components.generators.utils import print_streaming_chunk
@@ -17,8 +18,9 @@ def chat_messages():

 class TestOpenAIChatGenerator:
-    def test_init_default(self):
-        component = OpenAIChatGenerator(api_key="test-api-key")
+    def test_init_default(self, monkeypatch):
+        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
+        component = OpenAIChatGenerator()
         assert component.client.api_key == "test-api-key"
         assert component.model == "gpt-3.5-turbo"
         assert component.streaming_callback is None
@@ -26,12 +28,12 @@ class TestOpenAIChatGenerator:
     def test_init_fail_wo_api_key(self, monkeypatch):
         monkeypatch.delenv("OPENAI_API_KEY", raising=False)
-        with pytest.raises(OpenAIError):
+        with pytest.raises(ValueError, match="None of the .* environment variables are set"):
             OpenAIChatGenerator()

     def test_init_with_parameters(self):
         component = OpenAIChatGenerator(
-            api_key="test-api-key",
+            api_key=Secret.from_token("test-api-key"),
             model="gpt-4",
             streaming_callback=print_streaming_chunk,
             api_base_url="test-base-url",
@@ -42,12 +44,14 @@ class TestOpenAIChatGenerator:
         assert component.streaming_callback is print_streaming_chunk
         assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}

-    def test_to_dict_default(self):
-        component = OpenAIChatGenerator(api_key="test-api-key")
+    def test_to_dict_default(self, monkeypatch):
+        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
+        component = OpenAIChatGenerator()
         data = component.to_dict()
         assert data == {
             "type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
             "init_parameters": {
+                "api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
                 "model": "gpt-3.5-turbo",
                 "organization": None,
                 "streaming_callback": None,
@@ -56,9 +60,10 @@ class TestOpenAIChatGenerator:
             },
         }

-    def test_to_dict_with_parameters(self):
+    def test_to_dict_with_parameters(self, monkeypatch):
+        monkeypatch.setenv("ENV_VAR", "test-api-key")
         component = OpenAIChatGenerator(
-            api_key="test-api-key",
+            api_key=Secret.from_env_var("ENV_VAR"),
             model="gpt-4",
             streaming_callback=print_streaming_chunk,
             api_base_url="test-base-url",
@@ -68,6 +73,7 @@ class TestOpenAIChatGenerator:
         assert data == {
             "type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
             "init_parameters": {
+                "api_key": {"env_vars": ["ENV_VAR"], "strict": True, "type": "env_var"},
                 "model": "gpt-4",
                 "organization": None,
                 "api_base_url": "test-base-url",
@@ -76,9 +82,9 @@ class TestOpenAIChatGenerator:
             },
         }

-    def test_to_dict_with_lambda_streaming_callback(self):
+    def test_to_dict_with_lambda_streaming_callback(self, monkeypatch):
+        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
         component = OpenAIChatGenerator(
-            api_key="test-api-key",
             model="gpt-4",
             streaming_callback=lambda x: x,
             api_base_url="test-base-url",
@@ -88,6 +94,7 @@ class TestOpenAIChatGenerator:
         assert data == {
             "type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
             "init_parameters": {
+                "api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
                 "model": "gpt-4",
                 "organization": None,
                 "api_base_url": "test-base-url",
@@ -96,10 +103,12 @@ class TestOpenAIChatGenerator:
             },
         }

-    def test_from_dict(self):
+    def test_from_dict(self, monkeypatch):
+        monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key")
         data = {
             "type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
             "init_parameters": {
+                "api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
                 "model": "gpt-4",
                 "api_base_url": "test-base-url",
                 "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
@@ -111,12 +120,14 @@ class TestOpenAIChatGenerator:
         assert component.streaming_callback is print_streaming_chunk
         assert component.api_base_url == "test-base-url"
         assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
+        assert component.api_key == Secret.from_env_var("OPENAI_API_KEY")

     def test_from_dict_fail_wo_env_var(self, monkeypatch):
         monkeypatch.delenv("OPENAI_API_KEY", raising=False)
         data = {
             "type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
             "init_parameters": {
+                "api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
                 "model": "gpt-4",
                 "organization": None,
                 "api_base_url": "test-base-url",
@@ -124,11 +135,11 @@ class TestOpenAIChatGenerator:
                 "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
             },
         }
-        with pytest.raises(OpenAIError):
+        with pytest.raises(ValueError, match="None of the .* environment variables are set"):
             OpenAIChatGenerator.from_dict(data)

     def test_run(self, chat_messages, mock_chat_completion):
-        component = OpenAIChatGenerator()
+        component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"))
         response = component.run(chat_messages)

         # check that the component returns the correct ChatMessage response
@@ -139,7 +150,9 @@ class TestOpenAIChatGenerator:
         assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
def test_run_with_params(self, chat_messages, mock_chat_completion): def test_run_with_params(self, chat_messages, mock_chat_completion):
component = OpenAIChatGenerator(generation_kwargs={"max_tokens": 10, "temperature": 0.5}) component = OpenAIChatGenerator(
api_key=Secret.from_token("test-api-key"), generation_kwargs={"max_tokens": 10, "temperature": 0.5}
)
response = component.run(chat_messages) response = component.run(chat_messages)
# check that the component calls the OpenAI API with the correct parameters # check that the component calls the OpenAI API with the correct parameters
@ -161,7 +174,9 @@ class TestOpenAIChatGenerator:
nonlocal streaming_callback_called nonlocal streaming_callback_called
streaming_callback_called = True streaming_callback_called = True
component = OpenAIChatGenerator(streaming_callback=streaming_callback) component = OpenAIChatGenerator(
api_key=Secret.from_token("test-api-key"), streaming_callback=streaming_callback
)
response = component.run(chat_messages) response = component.run(chat_messages)
# check we called the streaming callback # check we called the streaming callback
@ -176,7 +191,7 @@ class TestOpenAIChatGenerator:
assert "Hello" in response["replies"][0].content # see mock_chat_completion_chunk assert "Hello" in response["replies"][0].content # see mock_chat_completion_chunk
def test_check_abnormal_completions(self, caplog): def test_check_abnormal_completions(self, caplog):
component = OpenAIChatGenerator(api_key="test-api-key") component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"))
messages = [ messages = [
ChatMessage.from_assistant( ChatMessage.from_assistant(
"", meta={"finish_reason": "content_filter" if i % 2 == 0 else "length", "index": i} "", meta={"finish_reason": "content_filter" if i % 2 == 0 else "length", "index": i}
@ -208,7 +223,7 @@ class TestOpenAIChatGenerator:
@pytest.mark.integration @pytest.mark.integration
def test_live_run(self): def test_live_run(self):
chat_messages = [ChatMessage.from_user("What's the capital of France")] chat_messages = [ChatMessage.from_user("What's the capital of France")]
component = OpenAIChatGenerator(api_key=os.environ.get("OPENAI_API_KEY"), generation_kwargs={"n": 1}) component = OpenAIChatGenerator(generation_kwargs={"n": 1})
results = component.run(chat_messages) results = component.run(chat_messages)
assert len(results["replies"]) == 1 assert len(results["replies"]) == 1
message: ChatMessage = results["replies"][0] message: ChatMessage = results["replies"][0]
@ -222,7 +237,7 @@ class TestOpenAIChatGenerator:
) )
@pytest.mark.integration @pytest.mark.integration
def test_live_run_wrong_model(self, chat_messages): def test_live_run_wrong_model(self, chat_messages):
component = OpenAIChatGenerator(model="something-obviously-wrong", api_key=os.environ.get("OPENAI_API_KEY")) component = OpenAIChatGenerator(model="something-obviously-wrong")
with pytest.raises(OpenAIError): with pytest.raises(OpenAIError):
component.run(chat_messages) component.run(chat_messages)
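
To summarize the contract these chat-generator tests pin down: api_key is now a Secret, defaulting to a strict env-var secret, and raw strings must be wrapped explicitly. A minimal sketch of the two construction modes (key values are placeholders):

from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.utils import Secret

# Default: the key is read from OPENAI_API_KEY; a missing variable raises a
# ValueError ("None of the ... environment variables are set") at init time.
llm = OpenAIChatGenerator()

# Explicit: wrap a raw key in a token-based Secret; passing a bare string is
# no longer supported.
llm = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"))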

View File

@@ -1,4 +1,5 @@
 import os
+from haystack.utils.auth import Secret

 import pytest
 from openai import OpenAIError
@@ -8,8 +9,9 @@ from haystack.components.generators.utils import print_streaming_chunk

 class TestAzureOpenAIGenerator:
-    def test_init_default(self):
-        component = AzureOpenAIGenerator(api_key="test-api-key", azure_endpoint="some-non-existing-endpoint")
+    def test_init_default(self, monkeypatch):
+        monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key")
+        component = AzureOpenAIGenerator(azure_endpoint="some-non-existing-endpoint")
         assert component.client.api_key == "test-api-key"
         assert component.azure_deployment == "gpt-35-turbo"
         assert component.streaming_callback is None
@@ -17,28 +19,32 @@ class TestAzureOpenAIGenerator:
     def test_init_fail_wo_api_key(self, monkeypatch):
         monkeypatch.delenv("AZURE_OPENAI_API_KEY", raising=False)
+        monkeypatch.delenv("AZURE_OPENAI_AD_TOKEN", raising=False)
         with pytest.raises(OpenAIError):
             AzureOpenAIGenerator(azure_endpoint="some-non-existing-endpoint")

     def test_init_with_parameters(self):
         component = AzureOpenAIGenerator(
-            api_key="test-api-key",
+            api_key=Secret.from_token("fake-api-key"),
             azure_endpoint="some-non-existing-endpoint",
             azure_deployment="gpt-35-turbo",
             streaming_callback=print_streaming_chunk,
             generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
         )
-        assert component.client.api_key == "test-api-key"
+        assert component.client.api_key == "fake-api-key"
         assert component.azure_deployment == "gpt-35-turbo"
         assert component.streaming_callback is print_streaming_chunk
         assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}

-    def test_to_dict_default(self):
-        component = AzureOpenAIGenerator(api_key="test-api-key", azure_endpoint="some-non-existing-endpoint")
+    def test_to_dict_default(self, monkeypatch):
+        monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key")
+        component = AzureOpenAIGenerator(azure_endpoint="some-non-existing-endpoint")
         data = component.to_dict()
         assert data == {
             "type": "haystack.components.generators.azure.AzureOpenAIGenerator",
             "init_parameters": {
+                "api_key": {"env_vars": ["AZURE_OPENAI_API_KEY"], "strict": False, "type": "env_var"},
+                "azure_ad_token": {"env_vars": ["AZURE_OPENAI_AD_TOKEN"], "strict": False, "type": "env_var"},
                 "azure_deployment": "gpt-35-turbo",
                 "api_version": "2023-05-15",
                 "streaming_callback": None,
@@ -49,9 +55,11 @@ class TestAzureOpenAIGenerator:
             },
         }

-    def test_to_dict_with_parameters(self):
+    def test_to_dict_with_parameters(self, monkeypatch):
+        monkeypatch.setenv("ENV_VAR", "test-api-key")
         component = AzureOpenAIGenerator(
-            api_key="test-api-key",
+            api_key=Secret.from_env_var("ENV_VAR", strict=False),
+            azure_ad_token=Secret.from_env_var("ENV_VAR1", strict=False),
             azure_endpoint="some-non-existing-endpoint",
             generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
         )
@@ -60,6 +68,8 @@ class TestAzureOpenAIGenerator:
         assert data == {
             "type": "haystack.components.generators.azure.AzureOpenAIGenerator",
             "init_parameters": {
+                "api_key": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
+                "azure_ad_token": {"env_vars": ["ENV_VAR1"], "strict": False, "type": "env_var"},
                 "azure_deployment": "gpt-35-turbo",
                 "api_version": "2023-05-15",
                 "streaming_callback": None,

View File

@@ -4,14 +4,16 @@ from unittest.mock import Mock, patch
 import pytest
 import torch
 from transformers import PreTrainedTokenizerFast
+from haystack.utils.auth import Secret

 from haystack.components.generators.hugging_face_local import HuggingFaceLocalGenerator, StopWordsCriteria
-from haystack.utils import ComponentDevice, Device
+from haystack.utils import ComponentDevice

 class TestHuggingFaceLocalGenerator:
     @patch("haystack.components.generators.hugging_face_local.model_info")
-    def test_init_default(self, model_info_mock):
+    def test_init_default(self, model_info_mock, monkeypatch):
+        monkeypatch.delenv("HF_API_TOKEN", raising=False)
         model_info_mock.return_value.pipeline_tag = "text2text-generation"
         generator = HuggingFaceLocalGenerator()
@@ -26,30 +28,33 @@ class TestHuggingFaceLocalGenerator:
     def test_init_custom_token(self):
         generator = HuggingFaceLocalGenerator(
-            model="google/flan-t5-base", task="text2text-generation", token="test-token"
+            model="google/flan-t5-base", task="text2text-generation", token=Secret.from_token("fake-api-token")
         )

         assert generator.huggingface_pipeline_kwargs == {
             "model": "google/flan-t5-base",
             "task": "text2text-generation",
-            "token": "test-token",
+            "token": "fake-api-token",
             "device": ComponentDevice.resolve_device(None).to_hf(),
         }

     def test_init_custom_device(self):
         generator = HuggingFaceLocalGenerator(
-            model="google/flan-t5-base", task="text2text-generation", device=ComponentDevice.from_str("cuda:0")
+            model="google/flan-t5-base",
+            task="text2text-generation",
+            device=ComponentDevice.from_str("cuda:0"),
+            token=Secret.from_token("fake-api-token"),
         )

         assert generator.huggingface_pipeline_kwargs == {
             "model": "google/flan-t5-base",
             "task": "text2text-generation",
-            "token": None,
+            "token": "fake-api-token",
             "device": "cuda:0",
         }

     def test_init_task_parameter(self):
-        generator = HuggingFaceLocalGenerator(task="text2text-generation")
+        generator = HuggingFaceLocalGenerator(task="text2text-generation", token=None)

         assert generator.huggingface_pipeline_kwargs == {
             "model": "google/flan-t5-base",
@@ -59,7 +64,7 @@ class TestHuggingFaceLocalGenerator:
         }

     def test_init_task_in_huggingface_pipeline_kwargs(self):
-        generator = HuggingFaceLocalGenerator(huggingface_pipeline_kwargs={"task": "text2text-generation"})
+        generator = HuggingFaceLocalGenerator(huggingface_pipeline_kwargs={"task": "text2text-generation"}, token=None)

         assert generator.huggingface_pipeline_kwargs == {
             "model": "google/flan-t5-base",
@@ -71,7 +76,7 @@ class TestHuggingFaceLocalGenerator:
     @patch("haystack.components.generators.hugging_face_local.model_info")
     def test_init_task_inferred_from_model_name(self, model_info_mock):
         model_info_mock.return_value.pipeline_tag = "text2text-generation"
-        generator = HuggingFaceLocalGenerator(model="google/flan-t5-base")
+        generator = HuggingFaceLocalGenerator(model="google/flan-t5-base", token=None)

         assert generator.huggingface_pipeline_kwargs == {
             "model": "google/flan-t5-base",
@@ -101,7 +106,7 @@ class TestHuggingFaceLocalGenerator:
             model="google/flan-t5-base",
             task="text2text-generation",
             device=ComponentDevice.from_str("cpu"),
-            token="test-token",
+            token=None,
             huggingface_pipeline_kwargs=huggingface_pipeline_kwargs,
         )
@@ -142,10 +147,10 @@ class TestHuggingFaceLocalGenerator:
         assert data == {
             "type": "haystack.components.generators.hugging_face_local.HuggingFaceLocalGenerator",
             "init_parameters": {
+                "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
                 "huggingface_pipeline_kwargs": {
                     "model": "google/flan-t5-base",
                     "task": "text2text-generation",
-                    "token": None,
                     "device": ComponentDevice.resolve_device(None).to_hf(),
                 },
                 "generation_kwargs": {},
@@ -158,7 +163,7 @@ class TestHuggingFaceLocalGenerator:
             model="gpt2",
             task="text-generation",
             device=ComponentDevice.from_str("cuda:0"),
-            token="test-token",
+            token=Secret.from_env_var("ENV_VAR", strict=False),
             generation_kwargs={"max_new_tokens": 100},
             stop_words=["coca", "cola"],
             huggingface_pipeline_kwargs={
@@ -175,6 +180,7 @@ class TestHuggingFaceLocalGenerator:
         assert data == {
             "type": "haystack.components.generators.hugging_face_local.HuggingFaceLocalGenerator",
             "init_parameters": {
+                "token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
                 "huggingface_pipeline_kwargs": {
                     "model": "gpt2",
                     "task": "text-generation",
@@ -197,7 +203,7 @@ class TestHuggingFaceLocalGenerator:
             model="gpt2",
             task="text-generation",
             device=ComponentDevice.from_str("cuda:0"),
-            token="test-token",
+            token=None,
             generation_kwargs={"max_new_tokens": 100},
             stop_words=["coca", "cola"],
             huggingface_pipeline_kwargs={
@@ -216,6 +222,7 @@ class TestHuggingFaceLocalGenerator:
         assert data == {
             "type": "haystack.components.generators.hugging_face_local.HuggingFaceLocalGenerator",
             "init_parameters": {
+                "token": None,
                 "huggingface_pipeline_kwargs": {
                     "model": "gpt2",
                     "task": "text-generation",
@@ -239,6 +246,7 @@ class TestHuggingFaceLocalGenerator:
         data = {
             "type": "haystack.components.generators.hugging_face_local.HuggingFaceLocalGenerator",
             "init_parameters": {
+                "token": None,
                 "huggingface_pipeline_kwargs": {
                     "model": "gpt2",
                     "task": "text-generation",
@@ -276,9 +284,7 @@ class TestHuggingFaceLocalGenerator:
     @patch("haystack.components.generators.hugging_face_local.pipeline")
     def test_warm_up(self, pipeline_mock):
-        generator = HuggingFaceLocalGenerator(
-            model="google/flan-t5-base", task="text2text-generation", token="test-token"
-        )
+        generator = HuggingFaceLocalGenerator(model="google/flan-t5-base", task="text2text-generation", token=None)

         pipeline_mock.assert_not_called()

         generator.warm_up()
@@ -286,14 +292,14 @@ class TestHuggingFaceLocalGenerator:
         pipeline_mock.assert_called_once_with(
             model="google/flan-t5-base",
             task="text2text-generation",
-            token="test-token",
+            token=None,
             device=ComponentDevice.resolve_device(None).to_hf(),
         )

     @patch("haystack.components.generators.hugging_face_local.pipeline")
     def test_warm_up_doesn_reload(self, pipeline_mock):
         generator = HuggingFaceLocalGenerator(
-            model="google/flan-t5-base", task="text2text-generation", token="test-token"
+            model="google/flan-t5-base", task="text2text-generation", token=Secret.from_token("fake-api-token")
         )

         pipeline_mock.assert_not_called()
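
The local generator behaves a little differently again: the Secret is resolved to a plain value when the transformers pipeline is built, and to_dict keeps the secret at the top level rather than leaking a resolved token inside huggingface_pipeline_kwargs. A sketch of the three token modes the tests above cover:

from haystack.components.generators.hugging_face_local import HuggingFaceLocalGenerator
from haystack.utils import Secret

# Default: a non-strict HF_API_TOKEN env-var secret.
gen_default = HuggingFaceLocalGenerator(task="text2text-generation")

# Explicit token: the resolved value is what ends up in the pipeline kwargs.
gen_token = HuggingFaceLocalGenerator(
    task="text2text-generation", token=Secret.from_token("fake-api-token")
)

# token=None opts out of Hub authentication entirely.
gen_anon = HuggingFaceLocalGenerator(task="text2text-generation", token=None)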

View File

@@ -1,4 +1,5 @@
 from unittest.mock import patch, MagicMock, Mock
+from haystack.utils.auth import Secret

 import pytest
 from huggingface_hub.inference._text_generation import TextGenerationStreamResponse, Token, StreamDetails, FinishReason
@@ -59,7 +60,10 @@ class TestHuggingFaceTGIGenerator:
     def test_to_dict(self, mock_check_valid_model):
         # Initialize the HuggingFaceRemoteGenerator object with valid parameters
         generator = HuggingFaceTGIGenerator(
-            token="token", generation_kwargs={"n": 5}, stop_words=["stop", "words"], streaming_callback=lambda x: x
+            token=Secret.from_env_var("ENV_VAR", strict=False),
+            generation_kwargs={"n": 5},
+            stop_words=["stop", "words"],
+            streaming_callback=lambda x: x,
         )

         # Call the to_dict method
@@ -68,7 +72,7 @@ class TestHuggingFaceTGIGenerator:
         # Assert that the init_params dictionary contains the expected keys and values
         assert init_params["model"] == "mistralai/Mistral-7B-v0.1"
-        assert not init_params["token"]
+        assert init_params["token"] == {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}
         assert init_params["generation_kwargs"] == {"n": 5, "stop_sequences": ["stop", "words"]}

     def test_from_dict(self, mock_check_valid_model):
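
Across all of these components, an env-var Secret serializes to the same small descriptor instead of a value, which is what the to_dict assertions compare against. As a sketch:

from haystack.utils import Secret

secret = Secret.from_env_var("ENV_VAR", strict=False)
assert secret.to_dict() == {"type": "env_var", "env_vars": ["ENV_VAR"], "strict": False}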

View File

@@ -1,5 +1,6 @@
 import os
 from typing import List
+from haystack.utils.auth import Secret

 import pytest
 from openai import OpenAIError
@@ -10,8 +11,9 @@ from haystack.dataclasses import StreamingChunk, ChatMessage

 class TestOpenAIGenerator:
-    def test_init_default(self):
-        component = OpenAIGenerator(api_key="test-api-key")
+    def test_init_default(self, monkeypatch):
+        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
+        component = OpenAIGenerator()
         assert component.client.api_key == "test-api-key"
         assert component.model == "gpt-3.5-turbo"
         assert component.streaming_callback is None
@@ -19,12 +21,12 @@ class TestOpenAIGenerator:
     def test_init_fail_wo_api_key(self, monkeypatch):
         monkeypatch.delenv("OPENAI_API_KEY", raising=False)
-        with pytest.raises(OpenAIError):
+        with pytest.raises(ValueError, match="None of the .* environment variables are set"):
             OpenAIGenerator()

     def test_init_with_parameters(self):
         component = OpenAIGenerator(
-            api_key="test-api-key",
+            api_key=Secret.from_token("test-api-key"),
             model="gpt-4",
             streaming_callback=print_streaming_chunk,
             api_base_url="test-base-url",
@@ -35,12 +37,14 @@ class TestOpenAIGenerator:
         assert component.streaming_callback is print_streaming_chunk
         assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}

-    def test_to_dict_default(self):
-        component = OpenAIGenerator(api_key="test-api-key")
+    def test_to_dict_default(self, monkeypatch):
+        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
+        component = OpenAIGenerator()
         data = component.to_dict()
         assert data == {
             "type": "haystack.components.generators.openai.OpenAIGenerator",
             "init_parameters": {
+                "api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
                 "model": "gpt-3.5-turbo",
                 "streaming_callback": None,
                 "system_prompt": None,
@@ -49,9 +53,10 @@ class TestOpenAIGenerator:
             },
         }

-    def test_to_dict_with_parameters(self):
+    def test_to_dict_with_parameters(self, monkeypatch):
+        monkeypatch.setenv("ENV_VAR", "test-api-key")
         component = OpenAIGenerator(
-            api_key="test-api-key",
+            api_key=Secret.from_env_var("ENV_VAR"),
             model="gpt-4",
             streaming_callback=print_streaming_chunk,
             api_base_url="test-base-url",
@@ -61,6 +66,7 @@ class TestOpenAIGenerator:
         assert data == {
             "type": "haystack.components.generators.openai.OpenAIGenerator",
             "init_parameters": {
+                "api_key": {"env_vars": ["ENV_VAR"], "strict": True, "type": "env_var"},
                 "model": "gpt-4",
                 "system_prompt": None,
                 "api_base_url": "test-base-url",
@@ -69,9 +75,9 @@ class TestOpenAIGenerator:
             },
         }

-    def test_to_dict_with_lambda_streaming_callback(self):
+    def test_to_dict_with_lambda_streaming_callback(self, monkeypatch):
+        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
         component = OpenAIGenerator(
-            api_key="test-api-key",
             model="gpt-4",
             streaming_callback=lambda x: x,
             api_base_url="test-base-url",
@@ -81,6 +87,7 @@ class TestOpenAIGenerator:
         assert data == {
             "type": "haystack.components.generators.openai.OpenAIGenerator",
             "init_parameters": {
+                "api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
                 "model": "gpt-4",
                 "system_prompt": None,
                 "api_base_url": "test-base-url",
@@ -94,6 +101,7 @@ class TestOpenAIGenerator:
         data = {
             "type": "haystack.components.generators.openai.OpenAIGenerator",
             "init_parameters": {
+                "api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
                 "model": "gpt-4",
                 "system_prompt": None,
                 "api_base_url": "test-base-url",
@@ -106,23 +114,25 @@ class TestOpenAIGenerator:
         assert component.streaming_callback is print_streaming_chunk
         assert component.api_base_url == "test-base-url"
         assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
+        assert component.api_key == Secret.from_env_var("OPENAI_API_KEY")

     def test_from_dict_fail_wo_env_var(self, monkeypatch):
         monkeypatch.delenv("OPENAI_API_KEY", raising=False)
         data = {
             "type": "haystack.components.generators.openai.OpenAIGenerator",
             "init_parameters": {
+                "api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
                 "model": "gpt-4",
                 "api_base_url": "test-base-url",
                 "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
                 "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
             },
         }
-        with pytest.raises(OpenAIError):
+        with pytest.raises(ValueError, match="None of the .* environment variables are set"):
             OpenAIGenerator.from_dict(data)

     def test_run(self, mock_chat_completion):
-        component = OpenAIGenerator(api_key="test-api-key")
+        component = OpenAIGenerator(api_key=Secret.from_token("test-api-key"))
         response = component.run("What's Natural Language Processing?")

         # check that the component returns the correct ChatMessage response
@@ -139,7 +149,7 @@ class TestOpenAIGenerator:
             nonlocal streaming_callback_called
             streaming_callback_called = True

-        component = OpenAIGenerator(streaming_callback=streaming_callback)
+        component = OpenAIGenerator(api_key=Secret.from_token("test-api-key"), streaming_callback=streaming_callback)
         response = component.run("Come on, stream!")

         # check we called the streaming callback
@@ -153,7 +163,9 @@ class TestOpenAIGenerator:
         assert "Hello" in response["replies"][0]  # see mock_chat_completion_chunk

     def test_run_with_params(self, mock_chat_completion):
-        component = OpenAIGenerator(api_key="test-api-key", generation_kwargs={"max_tokens": 10, "temperature": 0.5})
+        component = OpenAIGenerator(
+            api_key=Secret.from_token("test-api-key"), generation_kwargs={"max_tokens": 10, "temperature": 0.5}
+        )
         response = component.run("What's Natural Language Processing?")

         # check that the component calls the OpenAI API with the correct parameters
@@ -169,7 +181,7 @@ class TestOpenAIGenerator:
         assert [isinstance(reply, str) for reply in response["replies"]]

     def test_check_abnormal_completions(self, caplog):
-        component = OpenAIGenerator(api_key="test-api-key")
+        component = OpenAIGenerator(api_key=Secret.from_token("test-api-key"))

         # underlying implementation uses ChatMessage objects so we have to use them here
         messages: List[ChatMessage] = []
@@ -202,7 +214,7 @@ class TestOpenAIGenerator:
     )
     @pytest.mark.integration
     def test_live_run(self):
-        component = OpenAIGenerator(api_key=os.environ.get("OPENAI_API_KEY"))
+        component = OpenAIGenerator()
         results = component.run("What's the capital of France?")
         assert len(results["replies"]) == 1
         assert len(results["meta"]) == 1
@@ -224,7 +236,7 @@ class TestOpenAIGenerator:
     )
     @pytest.mark.integration
     def test_live_run_wrong_model(self):
-        component = OpenAIGenerator(model="something-obviously-wrong", api_key=os.environ.get("OPENAI_API_KEY"))
+        component = OpenAIGenerator(model="something-obviously-wrong")
         with pytest.raises(OpenAIError):
             component.run("Whatever")
@@ -244,7 +256,7 @@ class TestOpenAIGenerator:
                 self.responses += chunk.content if chunk.content else ""

         callback = Callback()
-        component = OpenAIGenerator(os.environ.get("OPENAI_API_KEY"), streaming_callback=callback)
+        component = OpenAIGenerator(streaming_callback=callback)
         results = component.run("What's the capital of France?")

         assert len(results["replies"]) == 1

View File

@@ -1,4 +1,5 @@
 from unittest.mock import MagicMock, patch
+from haystack.utils.auth import Secret

 import pytest
 import logging
@@ -19,7 +20,7 @@ class TestSimilarityRanker:
             "init_parameters": {
                 "device": None,
                 "top_k": 10,
-                "token": None,
+                "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
                 "query_prefix": "",
                 "document_prefix": "",
                 "model": "cross-encoder/ms-marco-MiniLM-L-6-v2",
@@ -36,7 +37,7 @@ class TestSimilarityRanker:
         component = TransformersSimilarityRanker(
             model="my_model",
             device=ComponentDevice.from_str("cuda:0"),
-            token="my_token",
+            token=Secret.from_env_var("ENV_VAR", strict=False),
             top_k=5,
             query_prefix="query_instruction: ",
             document_prefix="document_instruction: ",
@@ -51,7 +52,7 @@ class TestSimilarityRanker:
             "init_parameters": {
                 "device": None,
                 "model": "my_model",
-                "token": None,  # we don't serialize valid tokens
+                "token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
                 "top_k": 5,
                 "query_prefix": "query_instruction: ",
                 "document_prefix": "document_instruction: ",
@@ -84,7 +85,7 @@ class TestSimilarityRanker:
                 "top_k": 10,
                 "query_prefix": "",
                 "document_prefix": "",
-                "token": None,
+                "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
                 "model": "cross-encoder/ms-marco-MiniLM-L-6-v2",
                 "meta_fields_to_embed": [],
                 "embedding_separator": "\n",
@@ -110,7 +111,7 @@ class TestSimilarityRanker:
         ],
     )
     def test_to_dict_device_map(self, device_map, expected):
-        component = TransformersSimilarityRanker(model_kwargs={"device_map": device_map})
+        component = TransformersSimilarityRanker(model_kwargs={"device_map": device_map}, token=None)
         data = component.to_dict()

         assert data == {

View File

@@ -1,5 +1,6 @@
 import os
 from unittest.mock import Mock, patch
+from haystack.utils.auth import Secret

 import pytest
 from requests import Timeout, RequestException, HTTPError
@@ -369,17 +370,19 @@ def mock_searchapi_search_result():

 class TestSearchApiSearchAPI:
     def test_init_fail_wo_api_key(self, monkeypatch):
         monkeypatch.delenv("SEARCHAPI_API_KEY", raising=False)
-        with pytest.raises(ValueError, match="SearchApiWebSearch expects an API key"):
+        with pytest.raises(ValueError, match="None of the .* environment variables are set"):
             SearchApiWebSearch()

-    def test_to_dict(self):
+    def test_to_dict(self, monkeypatch):
+        monkeypatch.setenv("SEARCHAPI_API_KEY", "test-api-key")
         component = SearchApiWebSearch(
-            api_key="api_key", top_k=10, allowed_domains=["testdomain.com"], search_params={"param": "test params"}
+            top_k=10, allowed_domains=["testdomain.com"], search_params={"param": "test params"}
         )
         data = component.to_dict()
         assert data == {
             "type": "haystack.components.websearch.searchapi.SearchApiWebSearch",
             "init_parameters": {
+                "api_key": {"env_vars": ["SEARCHAPI_API_KEY"], "strict": True, "type": "env_var"},
                 "top_k": 10,
                 "allowed_domains": ["testdomain.com"],
                 "search_params": {"param": "test params"},
@@ -388,7 +391,7 @@ class TestSearchApiSearchAPI:

     @pytest.mark.parametrize("top_k", [1, 5, 7])
     def test_web_search_top_k(self, mock_searchapi_search_result, top_k: int):
-        ws = SearchApiWebSearch(api_key="api_key", top_k=top_k)
+        ws = SearchApiWebSearch(api_key=Secret.from_token("test-api-key"), top_k=top_k)
         results = ws.run(query="Who is CEO of Microsoft?")
         documents = results["documents"]
         links = results["links"]
@@ -400,7 +403,7 @@ class TestSearchApiSearchAPI:
     @patch("requests.get")
     def test_timeout_error(self, mock_get):
         mock_get.side_effect = Timeout
-        ws = SearchApiWebSearch(api_key="api_key")
+        ws = SearchApiWebSearch(api_key=Secret.from_token("test-api-key"))

         with pytest.raises(TimeoutError):
             ws.run(query="Who is CEO of Microsoft?")
@@ -408,7 +411,7 @@ class TestSearchApiSearchAPI:
     @patch("requests.get")
     def test_request_exception(self, mock_get):
         mock_get.side_effect = RequestException
-        ws = SearchApiWebSearch(api_key="api_key")
+        ws = SearchApiWebSearch(api_key=Secret.from_token("test-api-key"))

         with pytest.raises(SearchApiError):
             ws.run(query="Who is CEO of Microsoft?")
@@ -418,7 +421,7 @@ class TestSearchApiSearchAPI:
         mock_response = mock_get.return_value
         mock_response.status_code = 404
         mock_response.raise_for_status.side_effect = HTTPError
-        ws = SearchApiWebSearch(api_key="api_key")
+        ws = SearchApiWebSearch(api_key=Secret.from_token("test-api-key"))

         with pytest.raises(SearchApiError):
             ws.run(query="Who is CEO of Microsoft?")
@@ -429,7 +432,7 @@ class TestSearchApiSearchAPI:
     )
     @pytest.mark.integration
     def test_web_search(self):
-        ws = SearchApiWebSearch(api_key=os.environ.get("SEARCHAPI_API_KEY", None), top_k=10)
+        ws = SearchApiWebSearch(top_k=10)
         results = ws.run(query="Who is CEO of Microsoft?")
         documents = results["documents"]
         links = results["links"]

View File

@@ -1,5 +1,6 @@
 import os
 from unittest.mock import Mock, patch
+from haystack.utils.auth import Secret

 import pytest
 from requests import Timeout, RequestException, HTTPError
@@ -110,22 +111,26 @@ def mock_serper_dev_search_result():

 class TestSerperDevSearchAPI:
     def test_init_fail_wo_api_key(self, monkeypatch):
         monkeypatch.delenv("SERPERDEV_API_KEY", raising=False)
-        with pytest.raises(ValueError, match="SerperDevWebSearch expects an API key"):
+        with pytest.raises(ValueError, match="None of the .* environment variables are set"):
             SerperDevWebSearch()

-    def test_to_dict(self):
-        component = SerperDevWebSearch(
-            api_key="test_key", top_k=10, allowed_domains=["test.com"], search_params={"param": "test"}
-        )
+    def test_to_dict(self, monkeypatch):
+        monkeypatch.setenv("SERPERDEV_API_KEY", "test-api-key")
+        component = SerperDevWebSearch(top_k=10, allowed_domains=["test.com"], search_params={"param": "test"})
         data = component.to_dict()
         assert data == {
             "type": "haystack.components.websearch.serper_dev.SerperDevWebSearch",
-            "init_parameters": {"top_k": 10, "allowed_domains": ["test.com"], "search_params": {"param": "test"}},
+            "init_parameters": {
+                "api_key": {"env_vars": ["SERPERDEV_API_KEY"], "strict": True, "type": "env_var"},
+                "top_k": 10,
+                "allowed_domains": ["test.com"],
+                "search_params": {"param": "test"},
+            },
         }

     @pytest.mark.parametrize("top_k", [1, 5, 7])
     def test_web_search_top_k(self, mock_serper_dev_search_result, top_k: int):
-        ws = SerperDevWebSearch(api_key="some_invalid_key", top_k=top_k)
+        ws = SerperDevWebSearch(api_key=Secret.from_token("test-api-key"), top_k=top_k)
         results = ws.run(query="Who is the boyfriend of Olivia Wilde?")
         documents = results["documents"]
         links = results["links"]
@@ -137,7 +142,7 @@ class TestSerperDevSearchAPI:
     @patch("requests.post")
     def test_timeout_error(self, mock_post):
         mock_post.side_effect = Timeout
-        ws = SerperDevWebSearch(api_key="some_invalid_key")
+        ws = SerperDevWebSearch(api_key=Secret.from_token("test-api-key"))

         with pytest.raises(TimeoutError):
             ws.run(query="Who is the boyfriend of Olivia Wilde?")
@@ -145,7 +150,7 @@ class TestSerperDevSearchAPI:
     @patch("requests.post")
     def test_request_exception(self, mock_post):
         mock_post.side_effect = RequestException
-        ws = SerperDevWebSearch(api_key="some_invalid_key")
+        ws = SerperDevWebSearch(api_key=Secret.from_token("test-api-key"))

         with pytest.raises(SerperDevError):
             ws.run(query="Who is the boyfriend of Olivia Wilde?")
@@ -155,7 +160,7 @@ class TestSerperDevSearchAPI:
         mock_response = mock_post.return_value
         mock_response.status_code = 404
         mock_response.raise_for_status.side_effect = HTTPError
-        ws = SerperDevWebSearch(api_key="some_invalid_key")
+        ws = SerperDevWebSearch(api_key=Secret.from_token("test-api-key"))

         with pytest.raises(SerperDevError):
             ws.run(query="Who is the boyfriend of Olivia Wilde?")
@@ -166,7 +171,7 @@ class TestSerperDevSearchAPI:
     )
     @pytest.mark.integration
     def test_web_search(self):
-        ws = SerperDevWebSearch(api_key=os.environ.get("SERPERDEV_API_KEY", None), top_k=10)
+        ws = SerperDevWebSearch(top_k=10)
         results = ws.run(query="Who is the boyfriend of Olivia Wilde?")
         documents = results["documents"]
         links = results["documents"]