mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-16 01:27:41 +00:00
feat!: Use Secret for passing authentication secrets to components (#6887)
* feat!: Use `Secret` for passing authentication secrets to components * Add comment to clarify type ignore
This commit is contained in:
parent
393a7993c3
commit
27d1af3068
@ -7,6 +7,7 @@ from openai import OpenAI
|
||||
|
||||
from haystack import Document, component, default_from_dict, default_to_dict
|
||||
from haystack.dataclasses import ByteStream
|
||||
from haystack.utils import Secret, deserialize_secrets_inplace
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -24,7 +25,7 @@ class RemoteWhisperTranscriber:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: Optional[str] = None,
|
||||
api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
|
||||
model: str = "whisper-1",
|
||||
api_base_url: Optional[str] = None,
|
||||
organization: Optional[str] = None,
|
||||
@ -61,6 +62,7 @@ class RemoteWhisperTranscriber:
|
||||
self.organization = organization
|
||||
self.model = model
|
||||
self.api_base_url = api_base_url
|
||||
self.api_key = api_key
|
||||
|
||||
# Only response_format = "json" is supported
|
||||
whisper_params = kwargs
|
||||
@ -71,7 +73,7 @@ class RemoteWhisperTranscriber:
|
||||
)
|
||||
whisper_params["response_format"] = "json"
|
||||
self.whisper_params = whisper_params
|
||||
self.client = OpenAI(api_key=api_key, organization=organization, base_url=api_base_url)
|
||||
self.client = OpenAI(api_key=api_key.resolve_value(), organization=organization, base_url=api_base_url)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
@ -81,6 +83,7 @@ class RemoteWhisperTranscriber:
|
||||
"""
|
||||
return default_to_dict(
|
||||
self,
|
||||
api_key=self.api_key.to_dict(),
|
||||
model=self.model,
|
||||
organization=self.organization,
|
||||
api_base_url=self.api_base_url,
|
||||
@ -92,6 +95,7 @@ class RemoteWhisperTranscriber:
|
||||
"""
|
||||
Deserialize this component from a dictionary.
|
||||
"""
|
||||
deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"])
|
||||
return default_from_dict(cls, data)
|
||||
|
||||
@component.output_types(documents=List[Document])
|
||||
|
||||
@ -29,10 +29,11 @@ class DynamicChatPromptBuilder:
|
||||
from haystack.components.generators.chat import OpenAIChatGenerator
|
||||
from haystack.dataclasses import ChatMessage
|
||||
from haystack import Pipeline
|
||||
from haystack.utils import Secret
|
||||
|
||||
# no parameter init, we don't use any runtime template variables
|
||||
prompt_builder = DynamicChatPromptBuilder()
|
||||
llm = OpenAIChatGenerator(api_key="<your-api-key>", model="gpt-3.5-turbo")
|
||||
llm = OpenAIChatGenerator(api_key=Secret.from_token("<your-api-key>"), model="gpt-3.5-turbo")
|
||||
|
||||
pipe = Pipeline()
|
||||
pipe.add_component("prompt_builder", prompt_builder)
|
||||
|
||||
@ -21,9 +21,10 @@ class DynamicPromptBuilder:
|
||||
from haystack.components.builders import DynamicPromptBuilder
|
||||
from haystack.components.generators import OpenAIGenerator
|
||||
from haystack import Pipeline, component, Document
|
||||
from haystack.utils import Secret
|
||||
|
||||
prompt_builder = DynamicPromptBuilder(runtime_variables=["documents"])
|
||||
llm = OpenAIGenerator(api_key="<your-api-key>", model="gpt-3.5-turbo")
|
||||
llm = OpenAIGenerator(api_key=Secret.from_token("<your-api-key>"), model="gpt-3.5-turbo")
|
||||
|
||||
|
||||
@component
|
||||
|
||||
@ -1,12 +1,12 @@
|
||||
from pathlib import Path
|
||||
from typing import List, Union, Dict, Any, Optional
|
||||
import logging
|
||||
import os
|
||||
|
||||
from haystack.lazy_imports import LazyImport
|
||||
from haystack import component, Document, default_to_dict
|
||||
from haystack import component, Document, default_to_dict, default_from_dict
|
||||
from haystack.dataclasses import ByteStream
|
||||
from haystack.components.converters.utils import get_bytestream_from_source, normalize_metadata
|
||||
from haystack.utils import Secret, deserialize_secrets_inplace
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -29,8 +29,9 @@ class AzureOCRDocumentConverter:
|
||||
Usage example:
|
||||
```python
|
||||
from haystack.components.converters.azure import AzureOCRDocumentConverter
|
||||
from haystack.utils import Secret
|
||||
|
||||
converter = AzureOCRDocumentConverter()
|
||||
converter = AzureOCRDocumentConverter(endpoint="<url>", api_key=Secret.from_token("<your-api-key>"))
|
||||
results = converter.run(sources=["image-based-document.pdf"], meta={"date_added": datetime.now().isoformat()})
|
||||
documents = results["documents"]
|
||||
print(documents[0].content)
|
||||
@ -38,33 +39,23 @@ class AzureOCRDocumentConverter:
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, endpoint: str, api_key: Optional[str] = None, model_id: str = "prebuilt-read"):
|
||||
def __init__(
|
||||
self, endpoint: str, api_key: Secret = Secret.from_env_var("AZURE_AI_API_KEY"), model_id: str = "prebuilt-read"
|
||||
):
|
||||
"""
|
||||
Create an AzureOCRDocumentConverter component.
|
||||
|
||||
:param endpoint: The endpoint of your Azure resource.
|
||||
:param api_key: The key of your Azure resource. It can be
|
||||
explicitly provided or automatically read from the
|
||||
environment variable AZURE_AI_API_KEY (recommended).
|
||||
:param api_key: The key of your Azure resource.
|
||||
:param model_id: The model ID of the model you want to use. Please refer to [Azure documentation](https://learn.microsoft.com/en-us/azure/ai-services/document-intelligence/choose-model-feature)
|
||||
for a list of available models. Default: `"prebuilt-read"`.
|
||||
"""
|
||||
azure_import.check()
|
||||
|
||||
api_key = api_key or os.environ.get("AZURE_AI_API_KEY")
|
||||
# we check whether api_key is None or an empty string
|
||||
if not api_key:
|
||||
msg = (
|
||||
"AzureOCRDocumentConverter expects an API key. "
|
||||
"Set the AZURE_AI_API_KEY environment variable (recommended) or pass it explicitly."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
self.document_analysis_client = DocumentAnalysisClient(
|
||||
endpoint=endpoint, credential=AzureKeyCredential(api_key)
|
||||
)
|
||||
self.document_analysis_client = DocumentAnalysisClient(endpoint=endpoint, credential=AzureKeyCredential(api_key.resolve_value())) # type: ignore
|
||||
self.endpoint = endpoint
|
||||
self.model_id = model_id
|
||||
self.api_key = api_key
|
||||
|
||||
@component.output_types(documents=List[Document], raw_azure_response=List[Dict])
|
||||
def run(self, sources: List[Union[str, Path, ByteStream]], meta: Optional[List[Dict[str, Any]]] = None):
|
||||
@ -116,7 +107,15 @@ class AzureOCRDocumentConverter:
|
||||
"""
|
||||
Serialize this component to a dictionary.
|
||||
"""
|
||||
return default_to_dict(self, endpoint=self.endpoint, model_id=self.model_id)
|
||||
return default_to_dict(self, api_key=self.api_key.to_dict(), endpoint=self.endpoint, model_id=self.model_id)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "AzureOCRDocumentConverter":
|
||||
"""
|
||||
Deserialize this component from a dictionary.
|
||||
"""
|
||||
deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"])
|
||||
return default_from_dict(cls, data)
|
||||
|
||||
@staticmethod
|
||||
def _convert_azure_result_to_document(result: "AnalyzeResult", file_suffix: Optional[str] = None) -> Document:
|
||||
|
||||
@ -1,10 +1,11 @@
|
||||
import os
|
||||
from typing import List, Optional, Dict, Any, Tuple
|
||||
|
||||
from openai.lib.azure import AzureADTokenProvider, AzureOpenAI
|
||||
from openai.lib.azure import AzureOpenAI
|
||||
from tqdm import tqdm
|
||||
|
||||
from haystack import component, Document, default_to_dict
|
||||
from haystack import component, Document, default_to_dict, default_from_dict
|
||||
from haystack.utils import Secret, deserialize_secrets_inplace
|
||||
|
||||
|
||||
@component
|
||||
@ -34,9 +35,8 @@ class AzureOpenAIDocumentEmbedder:
|
||||
azure_endpoint: Optional[str] = None,
|
||||
api_version: Optional[str] = "2023-05-15",
|
||||
azure_deployment: str = "text-embedding-ada-002",
|
||||
api_key: Optional[str] = None,
|
||||
azure_ad_token: Optional[str] = None,
|
||||
azure_ad_token_provider: Optional[AzureADTokenProvider] = None,
|
||||
api_key: Optional[Secret] = Secret.from_env_var("AZURE_OPENAI_API_KEY", strict=False),
|
||||
azure_ad_token: Optional[Secret] = Secret.from_env_var("AZURE_OPENAI_AD_TOKEN", strict=False),
|
||||
organization: Optional[str] = None,
|
||||
prefix: str = "",
|
||||
suffix: str = "",
|
||||
@ -53,8 +53,6 @@ class AzureOpenAIDocumentEmbedder:
|
||||
:param azure_deployment: The deployment of the model, usually the model name.
|
||||
:param api_key: The API key to use for authentication.
|
||||
:param azure_ad_token: Azure Active Directory token, see https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id
|
||||
:param azure_ad_token_provider: A function that returns an Azure Active Directory token, will be invoked
|
||||
on every request.
|
||||
:param organization: The Organization ID, defaults to `None`. See
|
||||
[production best practices](https://platform.openai.com/docs/guides/production-best-practices/setting-up-your-organization).
|
||||
:param prefix: A string to add to the beginning of each text.
|
||||
@ -70,6 +68,11 @@ class AzureOpenAIDocumentEmbedder:
|
||||
if not azure_endpoint:
|
||||
raise ValueError("Please provide an Azure endpoint or set the environment variable AZURE_OPENAI_ENDPOINT.")
|
||||
|
||||
if api_key is None and azure_ad_token is None:
|
||||
raise ValueError("Please provide an API key or an Azure Active Directory token.")
|
||||
|
||||
self.api_key = api_key
|
||||
self.azure_ad_token = azure_ad_token
|
||||
self.api_version = api_version
|
||||
self.azure_endpoint = azure_endpoint
|
||||
self.azure_deployment = azure_deployment
|
||||
@ -85,9 +88,8 @@ class AzureOpenAIDocumentEmbedder:
|
||||
api_version=api_version,
|
||||
azure_endpoint=azure_endpoint,
|
||||
azure_deployment=azure_deployment,
|
||||
api_key=api_key,
|
||||
azure_ad_token=azure_ad_token,
|
||||
azure_ad_token_provider=azure_ad_token_provider,
|
||||
api_key=api_key.resolve_value() if api_key is not None else None,
|
||||
azure_ad_token=azure_ad_token.resolve_value() if azure_ad_token is not None else None,
|
||||
organization=organization,
|
||||
)
|
||||
|
||||
@ -98,10 +100,6 @@ class AzureOpenAIDocumentEmbedder:
|
||||
return {"model": self.azure_deployment}
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
This method overrides the default serializer in order to avoid leaking the `api_key` value passed
|
||||
to the constructor.
|
||||
"""
|
||||
return default_to_dict(
|
||||
self,
|
||||
azure_endpoint=self.azure_endpoint,
|
||||
@ -114,8 +112,15 @@ class AzureOpenAIDocumentEmbedder:
|
||||
progress_bar=self.progress_bar,
|
||||
meta_fields_to_embed=self.meta_fields_to_embed,
|
||||
embedding_separator=self.embedding_separator,
|
||||
api_key=self.api_key.to_dict() if self.api_key is not None else None,
|
||||
azure_ad_token=self.azure_ad_token.to_dict() if self.azure_ad_token is not None else None,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "AzureOpenAIDocumentEmbedder":
|
||||
deserialize_secrets_inplace(data["init_parameters"], keys=["api_key", "azure_ad_token"])
|
||||
return default_from_dict(cls, data)
|
||||
|
||||
def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
|
||||
"""
|
||||
Prepare the texts to embed by concatenating the Document text with the metadata fields to embed.
|
||||
|
||||
@ -1,9 +1,10 @@
|
||||
import os
|
||||
from typing import List, Optional, Dict, Any
|
||||
|
||||
from openai.lib.azure import AzureADTokenProvider, AzureOpenAI
|
||||
from openai.lib.azure import AzureOpenAI
|
||||
|
||||
from haystack import component, default_to_dict, Document
|
||||
from haystack import component, Document, default_to_dict, default_from_dict
|
||||
from haystack.utils import Secret, deserialize_secrets_inplace
|
||||
|
||||
|
||||
@component
|
||||
@ -32,9 +33,8 @@ class AzureOpenAITextEmbedder:
|
||||
azure_endpoint: Optional[str] = None,
|
||||
api_version: Optional[str] = "2023-05-15",
|
||||
azure_deployment: str = "text-embedding-ada-002",
|
||||
api_key: Optional[str] = None,
|
||||
azure_ad_token: Optional[str] = None,
|
||||
azure_ad_token_provider: Optional[AzureADTokenProvider] = None,
|
||||
api_key: Optional[Secret] = Secret.from_env_var("AZURE_OPENAI_API_KEY", strict=False),
|
||||
azure_ad_token: Optional[Secret] = Secret.from_env_var("AZURE_OPENAI_AD_TOKEN", strict=False),
|
||||
organization: Optional[str] = None,
|
||||
prefix: str = "",
|
||||
suffix: str = "",
|
||||
@ -47,8 +47,6 @@ class AzureOpenAITextEmbedder:
|
||||
:param azure_deployment: The deployment of the model, usually the model name.
|
||||
:param api_key: The API key to use for authentication.
|
||||
:param azure_ad_token: Azure Active Directory token, see https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id
|
||||
:param azure_ad_token_provider: A function that returns an Azure Active Directory token, will be invoked
|
||||
on every request.
|
||||
:param organization: The Organization ID, defaults to `None`. See
|
||||
[production best practices](https://platform.openai.com/docs/guides/production-best-practices/setting-up-your-organization).
|
||||
:param prefix: A string to add to the beginning of each text.
|
||||
@ -62,6 +60,11 @@ class AzureOpenAITextEmbedder:
|
||||
if not azure_endpoint:
|
||||
raise ValueError("Please provide an Azure endpoint or set the environment variable AZURE_OPENAI_ENDPOINT.")
|
||||
|
||||
if api_key is None and azure_ad_token is None:
|
||||
raise ValueError("Please provide an API key or an Azure Active Directory token.")
|
||||
|
||||
self.api_key = api_key
|
||||
self.azure_ad_token = azure_ad_token
|
||||
self.api_version = api_version
|
||||
self.azure_endpoint = azure_endpoint
|
||||
self.azure_deployment = azure_deployment
|
||||
@ -73,9 +76,8 @@ class AzureOpenAITextEmbedder:
|
||||
api_version=api_version,
|
||||
azure_endpoint=azure_endpoint,
|
||||
azure_deployment=azure_deployment,
|
||||
api_key=api_key,
|
||||
azure_ad_token=azure_ad_token,
|
||||
azure_ad_token_provider=azure_ad_token_provider,
|
||||
api_key=api_key.resolve_value() if api_key is not None else None,
|
||||
azure_ad_token=azure_ad_token.resolve_value() if azure_ad_token is not None else None,
|
||||
organization=organization,
|
||||
)
|
||||
|
||||
@ -98,8 +100,15 @@ class AzureOpenAITextEmbedder:
|
||||
api_version=self.api_version,
|
||||
prefix=self.prefix,
|
||||
suffix=self.suffix,
|
||||
api_key=self.api_key.to_dict() if self.api_key is not None else None,
|
||||
azure_ad_token=self.azure_ad_token.to_dict() if self.azure_ad_token is not None else None,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "AzureOpenAITextEmbedder":
|
||||
deserialize_secrets_inplace(data["init_parameters"], keys=["api_key", "azure_ad_token"])
|
||||
return default_from_dict(cls, data)
|
||||
|
||||
@component.output_types(embedding=List[float], meta=Dict[str, Any])
|
||||
def run(self, text: str):
|
||||
"""Embed a string using AzureOpenAITextEmbedder."""
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from typing import List, Optional, Union, Dict
|
||||
from typing import List, Optional, Dict
|
||||
|
||||
from haystack.lazy_imports import LazyImport
|
||||
from haystack.utils.auth import Secret
|
||||
|
||||
with LazyImport(message="Run 'pip install \"sentence-transformers>=2.2.0\"'") as sentence_transformers_import:
|
||||
from sentence_transformers import SentenceTransformer
|
||||
@ -14,14 +15,12 @@ class _SentenceTransformersEmbeddingBackendFactory:
|
||||
_instances: Dict[str, "_SentenceTransformersEmbeddingBackend"] = {}
|
||||
|
||||
@staticmethod
|
||||
def get_embedding_backend(model: str, device: Optional[str] = None, use_auth_token: Union[bool, str, None] = None):
|
||||
embedding_backend_id = f"{model}{device}{use_auth_token}"
|
||||
def get_embedding_backend(model: str, device: Optional[str] = None, auth_token: Optional[Secret] = None):
|
||||
embedding_backend_id = f"{model}{device}{auth_token}"
|
||||
|
||||
if embedding_backend_id in _SentenceTransformersEmbeddingBackendFactory._instances:
|
||||
return _SentenceTransformersEmbeddingBackendFactory._instances[embedding_backend_id]
|
||||
embedding_backend = _SentenceTransformersEmbeddingBackend(
|
||||
model=model, device=device, use_auth_token=use_auth_token
|
||||
)
|
||||
embedding_backend = _SentenceTransformersEmbeddingBackend(model=model, device=device, auth_token=auth_token)
|
||||
_SentenceTransformersEmbeddingBackendFactory._instances[embedding_backend_id] = embedding_backend
|
||||
return embedding_backend
|
||||
|
||||
@ -31,9 +30,11 @@ class _SentenceTransformersEmbeddingBackend:
|
||||
Class to manage Sentence Transformers embeddings.
|
||||
"""
|
||||
|
||||
def __init__(self, model: str, device: Optional[str] = None, use_auth_token: Union[bool, str, None] = None):
|
||||
def __init__(self, model: str, device: Optional[str] = None, auth_token: Optional[Secret] = None):
|
||||
sentence_transformers_import.check()
|
||||
self.model = SentenceTransformer(model_name_or_path=model, device=device, use_auth_token=use_auth_token)
|
||||
self.model = SentenceTransformer(
|
||||
model_name_or_path=model, device=device, use_auth_token=auth_token.resolve_value() if auth_token else None
|
||||
)
|
||||
|
||||
def embed(self, data: List[str], **kwargs) -> List[List[float]]:
|
||||
embeddings = self.model.encode(data, **kwargs).tolist()
|
||||
|
||||
@ -1,26 +1,27 @@
|
||||
from typing import Optional
|
||||
|
||||
from haystack.lazy_imports import LazyImport
|
||||
from haystack.utils.auth import Secret
|
||||
|
||||
with LazyImport(message="Run 'pip install transformers'") as transformers_import:
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub.utils import RepositoryNotFoundError
|
||||
|
||||
|
||||
def check_valid_model(model_id: str, token: Optional[str]) -> None:
|
||||
def check_valid_model(model_id: str, token: Optional[Secret]) -> None:
|
||||
"""
|
||||
Check if the provided model ID corresponds to a valid model on HuggingFace Hub.
|
||||
Also check if the model is an embedding model.
|
||||
|
||||
:param model_id: A string representing the HuggingFace model ID.
|
||||
:param token: An optional string representing the authentication token.
|
||||
:param token: The optional authentication token.
|
||||
:raises ValueError: If the model is not found or is not an embedding model.
|
||||
"""
|
||||
transformers_import.check()
|
||||
|
||||
api = HfApi()
|
||||
try:
|
||||
model_info = api.model_info(model_id, token=token)
|
||||
model_info = api.model_info(model_id, token=token.resolve_value() if token else None)
|
||||
except RepositoryNotFoundError as e:
|
||||
raise ValueError(
|
||||
f"Model {model_id} not found on HuggingFace Hub. Please provide a valid HuggingFace model_id."
|
||||
|
||||
@ -1,14 +1,14 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from haystack import component, default_to_dict
|
||||
from haystack.components.embedders.hf_utils import check_valid_model
|
||||
from haystack.dataclasses import Document
|
||||
from haystack.lazy_imports import LazyImport
|
||||
from haystack.utils import Secret, deserialize_secrets_inplace
|
||||
from haystack import component, default_to_dict, default_from_dict
|
||||
|
||||
with LazyImport(message="Run 'pip install transformers'") as transformers_import:
|
||||
from huggingface_hub import InferenceClient
|
||||
@ -29,11 +29,12 @@ class HuggingFaceTEIDocumentEmbedder:
|
||||
```python
|
||||
from haystack.dataclasses import Document
|
||||
from haystack.components.embedders import HuggingFaceTEIDocumentEmbedder
|
||||
from haystack.utils import Secret
|
||||
|
||||
doc = Document(content="I love pizza!")
|
||||
|
||||
document_embedder = HuggingFaceTEIDocumentEmbedder(
|
||||
model="BAAI/bge-small-en-v1.5", token="<your-token>"
|
||||
model="BAAI/bge-small-en-v1.5", token=Secret.from_token("<your-api-key>")
|
||||
)
|
||||
|
||||
result = document_embedder.run([doc])
|
||||
@ -52,7 +53,7 @@ class HuggingFaceTEIDocumentEmbedder:
|
||||
doc = Document(content="I love pizza!")
|
||||
|
||||
document_embedder = HuggingFaceTEIDocumentEmbedder(
|
||||
model="BAAI/bge-small-en-v1.5", url="<your-tei-endpoint-url>", token="<your-token>"
|
||||
model="BAAI/bge-small-en-v1.5", url="<your-tei-endpoint-url>", token=Secret.from_token("<your-api-key>")
|
||||
)
|
||||
|
||||
result = document_embedder.run([doc])
|
||||
@ -83,7 +84,7 @@ class HuggingFaceTEIDocumentEmbedder:
|
||||
self,
|
||||
model: str = "BAAI/bge-small-en-v1.5",
|
||||
url: Optional[str] = None,
|
||||
token: Optional[str] = None,
|
||||
token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
|
||||
prefix: str = "",
|
||||
suffix: str = "",
|
||||
batch_size: int = 32,
|
||||
@ -98,8 +99,7 @@ class HuggingFaceTEIDocumentEmbedder:
|
||||
:param url: The URL of your self-deployed Text-Embeddings-Inference service or the URL of your paid HF Inference
|
||||
Endpoint.
|
||||
:param token: The HuggingFace Hub token. This is needed if you are using a paid HF Inference Endpoint or serving
|
||||
a private or gated model. It can be explicitly provided or automatically read from the environment
|
||||
variable HF_API_TOKEN (recommended).
|
||||
a private or gated model.
|
||||
:param prefix: A string to add to the beginning of each text.
|
||||
:param suffix: A string to add to the end of each text.
|
||||
:param batch_size: Number of Documents to encode at once.
|
||||
@ -116,15 +116,12 @@ class HuggingFaceTEIDocumentEmbedder:
|
||||
if not is_valid_url:
|
||||
raise ValueError(f"Invalid TEI endpoint URL provided: {url}")
|
||||
|
||||
# The user does not need to provide a token if it is a local server or free public HF Inference Endpoint.
|
||||
token = token or os.environ.get("HF_API_TOKEN")
|
||||
|
||||
check_valid_model(model, token)
|
||||
|
||||
self.model = model
|
||||
self.url = url
|
||||
self.token = token
|
||||
self.client = InferenceClient(url or model, token=token)
|
||||
self.client = InferenceClient(url or model, token=token.resolve_value() if token else None)
|
||||
self.prefix = prefix
|
||||
self.suffix = suffix
|
||||
self.batch_size = batch_size
|
||||
@ -133,10 +130,6 @@ class HuggingFaceTEIDocumentEmbedder:
|
||||
self.embedding_separator = embedding_separator
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
This method overrides the default serializer in order to avoid leaking the `token` value passed
|
||||
to the constructor.
|
||||
"""
|
||||
return default_to_dict(
|
||||
self,
|
||||
model=self.model,
|
||||
@ -147,8 +140,14 @@ class HuggingFaceTEIDocumentEmbedder:
|
||||
progress_bar=self.progress_bar,
|
||||
meta_fields_to_embed=self.meta_fields_to_embed,
|
||||
embedding_separator=self.embedding_separator,
|
||||
token=self.token.to_dict() if self.token else None,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceTEIDocumentEmbedder":
|
||||
deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
|
||||
return default_from_dict(cls, data)
|
||||
|
||||
def _get_telemetry_data(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Data that is sent to Posthog for usage analytics.
|
||||
|
||||
@ -1,11 +1,11 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from haystack import component, default_to_dict
|
||||
from haystack import component, default_to_dict, default_from_dict
|
||||
from haystack.components.embedders.hf_utils import check_valid_model
|
||||
from haystack.lazy_imports import LazyImport
|
||||
from haystack.utils import Secret, deserialize_secrets_inplace
|
||||
|
||||
with LazyImport(message="Run 'pip install transformers'") as transformers_import:
|
||||
from huggingface_hub import InferenceClient
|
||||
@ -23,11 +23,12 @@ class HuggingFaceTEITextEmbedder:
|
||||
Inference API tier:
|
||||
```python
|
||||
from haystack.components.embedders import HuggingFaceTEITextEmbedder
|
||||
from haystack.utils import Secret
|
||||
|
||||
text_to_embed = "I love pizza!"
|
||||
|
||||
text_embedder = HuggingFaceTEITextEmbedder(
|
||||
model="BAAI/bge-small-en-v1.5", token="<your-token>"
|
||||
model="BAAI/bge-small-en-v1.5", token=Secret.from_token("<your-api-key>")
|
||||
)
|
||||
|
||||
print(text_embedder.run(text_to_embed))
|
||||
@ -44,7 +45,7 @@ class HuggingFaceTEITextEmbedder:
|
||||
text_to_embed = "I love pizza!"
|
||||
|
||||
text_embedder = HuggingFaceTEITextEmbedder(
|
||||
model="BAAI/bge-small-en-v1.5", url="<your-tei-endpoint-url>", token="<your-token>"
|
||||
model="BAAI/bge-small-en-v1.5", url="<your-tei-endpoint-url>", token=Secret.from_token("<your-api-key>")
|
||||
)
|
||||
|
||||
print(text_embedder.run(text_to_embed))
|
||||
@ -74,7 +75,7 @@ class HuggingFaceTEITextEmbedder:
|
||||
self,
|
||||
model: str = "BAAI/bge-small-en-v1.5",
|
||||
url: Optional[str] = None,
|
||||
token: Optional[str] = None,
|
||||
token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
|
||||
prefix: str = "",
|
||||
suffix: str = "",
|
||||
):
|
||||
@ -85,8 +86,7 @@ class HuggingFaceTEITextEmbedder:
|
||||
:param url: The URL of your self-deployed Text-Embeddings-Inference service or the URL of your paid HF Inference
|
||||
Endpoint.
|
||||
:param token: The HuggingFace Hub token. This is needed if you are using a paid HF Inference Endpoint or serving
|
||||
a private or gated model. It can be explicitly provided or automatically read from the environment
|
||||
variable HF_API_TOKEN (recommended).
|
||||
a private or gated model.
|
||||
:param prefix: A string to add to the beginning of each text.
|
||||
:param suffix: A string to add to the end of each text.
|
||||
"""
|
||||
@ -98,24 +98,29 @@ class HuggingFaceTEITextEmbedder:
|
||||
if not is_valid_url:
|
||||
raise ValueError(f"Invalid TEI endpoint URL provided: {url}")
|
||||
|
||||
# The user does not need to provide a token if it is a local server or free public HF Inference Endpoint.
|
||||
token = token or os.environ.get("HF_API_TOKEN")
|
||||
|
||||
check_valid_model(model, token)
|
||||
|
||||
self.model = model
|
||||
self.url = url
|
||||
self.token = token
|
||||
self.client = InferenceClient(url or model, token=token)
|
||||
self.client = InferenceClient(url or model, token=token.resolve_value() if token else None)
|
||||
self.prefix = prefix
|
||||
self.suffix = suffix
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
This method overrides the default serializer in order to avoid leaking the `token` value passed
|
||||
to the constructor.
|
||||
"""
|
||||
return default_to_dict(self, model=self.model, url=self.url, prefix=self.prefix, suffix=self.suffix)
|
||||
return default_to_dict(
|
||||
self,
|
||||
model=self.model,
|
||||
url=self.url,
|
||||
prefix=self.prefix,
|
||||
suffix=self.suffix,
|
||||
token=self.token.to_dict() if self.token else None,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceTEITextEmbedder":
|
||||
deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
|
||||
return default_from_dict(cls, data)
|
||||
|
||||
def _get_telemetry_data(self) -> Dict[str, Any]:
|
||||
"""
|
||||
|
||||
@ -3,7 +3,8 @@ from typing import List, Optional, Dict, Any, Tuple
|
||||
from openai import OpenAI
|
||||
from tqdm import tqdm
|
||||
|
||||
from haystack import component, Document, default_to_dict
|
||||
from haystack import component, Document, default_to_dict, default_from_dict
|
||||
from haystack.utils import Secret, deserialize_secrets_inplace
|
||||
|
||||
|
||||
@component
|
||||
@ -30,7 +31,7 @@ class OpenAIDocumentEmbedder:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: Optional[str] = None,
|
||||
api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
|
||||
model: str = "text-embedding-ada-002",
|
||||
api_base_url: Optional[str] = None,
|
||||
organization: Optional[str] = None,
|
||||
@ -43,8 +44,7 @@ class OpenAIDocumentEmbedder:
|
||||
):
|
||||
"""
|
||||
Create an OpenAIDocumentEmbedder component.
|
||||
:param api_key: The OpenAI API key. It can be explicitly provided or automatically read from the
|
||||
environment variable OPENAI_API_KEY (recommended).
|
||||
:param api_key: The OpenAI API key.
|
||||
:param model: The name of the model to use.
|
||||
:param api_base_url: The OpenAI API Base url, defaults to None. For more details, see OpenAI [docs](https://platform.openai.com/docs/api-reference/audio).
|
||||
:param organization: The Organization ID, defaults to `None`. See
|
||||
@ -57,6 +57,7 @@ class OpenAIDocumentEmbedder:
|
||||
:param meta_fields_to_embed: List of meta fields that should be embedded along with the Document text.
|
||||
:param embedding_separator: Separator used to concatenate the meta fields to the Document text.
|
||||
"""
|
||||
self.api_key = api_key
|
||||
self.model = model
|
||||
self.api_base_url = api_base_url
|
||||
self.organization = organization
|
||||
@ -67,7 +68,7 @@ class OpenAIDocumentEmbedder:
|
||||
self.meta_fields_to_embed = meta_fields_to_embed or []
|
||||
self.embedding_separator = embedding_separator
|
||||
|
||||
self.client = OpenAI(api_key=api_key, organization=organization, base_url=api_base_url)
|
||||
self.client = OpenAI(api_key=api_key.resolve_value(), organization=organization, base_url=api_base_url)
|
||||
|
||||
def _get_telemetry_data(self) -> Dict[str, Any]:
|
||||
"""
|
||||
@ -91,8 +92,14 @@ class OpenAIDocumentEmbedder:
|
||||
progress_bar=self.progress_bar,
|
||||
meta_fields_to_embed=self.meta_fields_to_embed,
|
||||
embedding_separator=self.embedding_separator,
|
||||
api_key=self.api_key.to_dict(),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "OpenAIDocumentEmbedder":
|
||||
deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"])
|
||||
return default_from_dict(cls, data)
|
||||
|
||||
def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
|
||||
"""
|
||||
Prepare the texts to embed by concatenating the Document text with the metadata fields to embed.
|
||||
|
||||
@ -2,7 +2,8 @@ from typing import List, Optional, Dict, Any
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from haystack import component, default_to_dict
|
||||
from haystack import component, default_to_dict, default_from_dict
|
||||
from haystack.utils import Secret, deserialize_secrets_inplace
|
||||
|
||||
|
||||
@component
|
||||
@ -28,7 +29,7 @@ class OpenAITextEmbedder:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: Optional[str] = None,
|
||||
api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
|
||||
model: str = "text-embedding-ada-002",
|
||||
api_base_url: Optional[str] = None,
|
||||
organization: Optional[str] = None,
|
||||
@ -38,8 +39,7 @@ class OpenAITextEmbedder:
|
||||
"""
|
||||
Create an OpenAITextEmbedder component.
|
||||
|
||||
:param api_key: The OpenAI API key. It can be explicitly provided or automatically read from the
|
||||
environment variable OPENAI_API_KEY (recommended).
|
||||
:param api_key: The OpenAI API key.
|
||||
:param model: The name of the OpenAI model to use. For more details on the available models,
|
||||
see [OpenAI documentation](https://platform.openai.com/docs/guides/embeddings/embedding-models).
|
||||
:param organization: The Organization ID, defaults to `None`. See
|
||||
@ -52,8 +52,9 @@ class OpenAITextEmbedder:
|
||||
self.organization = organization
|
||||
self.prefix = prefix
|
||||
self.suffix = suffix
|
||||
self.api_key = api_key
|
||||
|
||||
self.client = OpenAI(api_key=api_key, organization=organization, base_url=api_base_url)
|
||||
self.client = OpenAI(api_key=api_key.resolve_value(), organization=organization, base_url=api_base_url)
|
||||
|
||||
def _get_telemetry_data(self) -> Dict[str, Any]:
|
||||
"""
|
||||
@ -62,15 +63,20 @@ class OpenAITextEmbedder:
|
||||
return {"model": self.model}
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
This method overrides the default serializer in order to avoid leaking the `api_key` value passed
|
||||
to the constructor.
|
||||
"""
|
||||
|
||||
return default_to_dict(
|
||||
self, model=self.model, organization=self.organization, prefix=self.prefix, suffix=self.suffix
|
||||
self,
|
||||
model=self.model,
|
||||
organization=self.organization,
|
||||
prefix=self.prefix,
|
||||
suffix=self.suffix,
|
||||
api_key=self.api_key.to_dict(),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "OpenAITextEmbedder":
|
||||
deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"])
|
||||
return default_from_dict(cls, data)
|
||||
|
||||
@component.output_types(embedding=List[float], meta=Dict[str, Any])
|
||||
def run(self, text: str):
|
||||
"""Embed a string."""
|
||||
|
||||
@ -1,9 +1,10 @@
|
||||
from typing import List, Optional, Union, Dict, Any
|
||||
from typing import List, Optional, Dict, Any
|
||||
|
||||
from haystack import component, Document, default_to_dict
|
||||
from haystack import component, Document, default_to_dict, default_from_dict
|
||||
from haystack.components.embedders.backends.sentence_transformers_backend import (
|
||||
_SentenceTransformersEmbeddingBackendFactory,
|
||||
)
|
||||
from haystack.utils import Secret, deserialize_secrets_inplace
|
||||
|
||||
|
||||
@component
|
||||
@ -31,7 +32,7 @@ class SentenceTransformersDocumentEmbedder:
|
||||
self,
|
||||
model: str = "sentence-transformers/all-mpnet-base-v2",
|
||||
device: Optional[str] = None,
|
||||
token: Union[bool, str, None] = None,
|
||||
token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
|
||||
prefix: str = "",
|
||||
suffix: str = "",
|
||||
batch_size: int = 32,
|
||||
@ -48,8 +49,6 @@ class SentenceTransformersDocumentEmbedder:
|
||||
:param device: Device (like 'cuda' / 'cpu') that should be used for computation.
|
||||
Defaults to CPU.
|
||||
:param token: The API token used to download private models from Hugging Face.
|
||||
If this parameter is set to `True`, then the token generated when running
|
||||
`transformers-cli login` (stored in ~/.huggingface) will be used.
|
||||
:param prefix: A string to add to the beginning of each Document text before embedding.
|
||||
Can be used to prepend the text with an instruction, as required by some embedding models,
|
||||
such as E5 and bge.
|
||||
@ -87,7 +86,7 @@ class SentenceTransformersDocumentEmbedder:
|
||||
self,
|
||||
model=self.model,
|
||||
device=self.device,
|
||||
token=self.token if not isinstance(self.token, str) else None, # don't serialize valid tokens
|
||||
token=self.token.to_dict() if self.token else None,
|
||||
prefix=self.prefix,
|
||||
suffix=self.suffix,
|
||||
batch_size=self.batch_size,
|
||||
@ -97,13 +96,18 @@ class SentenceTransformersDocumentEmbedder:
|
||||
embedding_separator=self.embedding_separator,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "SentenceTransformersDocumentEmbedder":
|
||||
deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
|
||||
return default_from_dict(cls, data)
|
||||
|
||||
def warm_up(self):
|
||||
"""
|
||||
Load the embedding backend.
|
||||
"""
|
||||
if not hasattr(self, "embedding_backend"):
|
||||
self.embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(
|
||||
model=self.model, device=self.device, use_auth_token=self.token
|
||||
model=self.model, device=self.device, auth_token=self.token
|
||||
)
|
||||
|
||||
@component.output_types(documents=List[Document])
|
||||
|
||||
@ -1,9 +1,10 @@
|
||||
from typing import List, Optional, Union, Dict, Any
|
||||
from typing import List, Optional, Dict, Any
|
||||
|
||||
from haystack import component, default_to_dict
|
||||
from haystack import component, default_to_dict, default_from_dict
|
||||
from haystack.components.embedders.backends.sentence_transformers_backend import (
|
||||
_SentenceTransformersEmbeddingBackendFactory,
|
||||
)
|
||||
from haystack.utils import Secret, deserialize_secrets_inplace
|
||||
|
||||
|
||||
@component
|
||||
@ -30,7 +31,7 @@ class SentenceTransformersTextEmbedder:
|
||||
self,
|
||||
model: str = "sentence-transformers/all-mpnet-base-v2",
|
||||
device: Optional[str] = None,
|
||||
token: Union[bool, str, None] = None,
|
||||
token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
|
||||
prefix: str = "",
|
||||
suffix: str = "",
|
||||
batch_size: int = 32,
|
||||
@ -45,8 +46,6 @@ class SentenceTransformersTextEmbedder:
|
||||
:param device: Device (like 'cuda' / 'cpu') that should be used for computation.
|
||||
Defaults to CPU.
|
||||
:param token: The API token used to download private models from Hugging Face.
|
||||
If this parameter is set to `True`, then the token generated when running
|
||||
`transformers-cli login` (stored in ~/.huggingface) will be used.
|
||||
:param prefix: A string to add to the beginning of each Document text before embedding.
|
||||
Can be used to prepend the text with an instruction, as required by some embedding models,
|
||||
such as E5 and bge.
|
||||
@ -80,7 +79,7 @@ class SentenceTransformersTextEmbedder:
|
||||
self,
|
||||
model=self.model,
|
||||
device=self.device,
|
||||
token=self.token if not isinstance(self.token, str) else None, # don't serialize valid tokens
|
||||
token=self.token.to_dict() if self.token else None,
|
||||
prefix=self.prefix,
|
||||
suffix=self.suffix,
|
||||
batch_size=self.batch_size,
|
||||
@ -88,13 +87,18 @@ class SentenceTransformersTextEmbedder:
|
||||
normalize_embeddings=self.normalize_embeddings,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "SentenceTransformersTextEmbedder":
|
||||
deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
|
||||
return default_from_dict(cls, data)
|
||||
|
||||
def warm_up(self):
|
||||
"""
|
||||
Load the embedding backend.
|
||||
"""
|
||||
if not hasattr(self, "embedding_backend"):
|
||||
self.embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(
|
||||
model=self.model, device=self.device, use_auth_token=self.token
|
||||
model=self.model, device=self.device, auth_token=self.token
|
||||
)
|
||||
|
||||
@component.output_types(embedding=List[float])
|
||||
|
||||
@ -3,12 +3,13 @@ import os
|
||||
from typing import Optional, Callable, Dict, Any
|
||||
|
||||
# pylint: disable=import-error
|
||||
from openai.lib.azure import AzureADTokenProvider, AzureOpenAI
|
||||
from openai.lib.azure import AzureOpenAI
|
||||
|
||||
from haystack import default_to_dict, default_from_dict
|
||||
from haystack.components.generators import OpenAIGenerator
|
||||
from haystack.components.generators.utils import serialize_callback_handler, deserialize_callback_handler
|
||||
from haystack.dataclasses import StreamingChunk
|
||||
from haystack.utils import Secret, deserialize_secrets_inplace
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -28,8 +29,9 @@ class AzureOpenAIGenerator(OpenAIGenerator):
|
||||
|
||||
```python
|
||||
from haystack.components.generators import AzureOpenAIGenerator
|
||||
from haystack.utils import Secret
|
||||
client = AzureOpenAIGenerator(azure_endpoint="<Your Azure endpoint, e.g. `https://your-company.azure.openai.com/`>",
|
||||
api_key="<you api key>",
|
||||
api_key=Secret.from_token("<your-api-key>"),
|
||||
azure_deployment="<this is a model name, e.g. gpt-35-turbo>")
|
||||
response = client.run("What's Natural Language Processing? Be brief.")
|
||||
print(response)
|
||||
@ -56,9 +58,8 @@ class AzureOpenAIGenerator(OpenAIGenerator):
|
||||
azure_endpoint: Optional[str] = None,
|
||||
api_version: Optional[str] = "2023-05-15",
|
||||
azure_deployment: Optional[str] = "gpt-35-turbo",
|
||||
api_key: Optional[str] = None,
|
||||
azure_ad_token: Optional[str] = None,
|
||||
azure_ad_token_provider: Optional[AzureADTokenProvider] = None,
|
||||
api_key: Optional[Secret] = Secret.from_env_var("AZURE_OPENAI_API_KEY", strict=False),
|
||||
azure_ad_token: Optional[Secret] = Secret.from_env_var("AZURE_OPENAI_AD_TOKEN", strict=False),
|
||||
organization: Optional[str] = None,
|
||||
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
|
||||
system_prompt: Optional[str] = None,
|
||||
@ -70,8 +71,6 @@ class AzureOpenAIGenerator(OpenAIGenerator):
|
||||
:param azure_deployment: The deployment of the model, usually the model name.
|
||||
:param api_key: The API key to use for authentication.
|
||||
:param azure_ad_token: Azure Active Directory token, see https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id
|
||||
:param azure_ad_token_provider: A function that returns an Azure Active Directory token, will be invoked
|
||||
on every request.
|
||||
:param organization: The Organization ID, defaults to `None`. See
|
||||
[production best practices](https://platform.openai.com/docs/guides/production-best-practices/setting-up-your-organization).
|
||||
:param streaming_callback: A callback function that is called when a new token is received from the stream.
|
||||
@ -108,6 +107,13 @@ class AzureOpenAIGenerator(OpenAIGenerator):
|
||||
if not azure_endpoint:
|
||||
raise ValueError("Please provide an Azure endpoint or set the environment variable AZURE_OPENAI_ENDPOINT.")
|
||||
|
||||
if api_key is None and azure_ad_token is None:
|
||||
raise ValueError("Please provide an API key or an Azure Active Directory token.")
|
||||
|
||||
# The check above makes mypy incorrectly infer that api_key is never None,
|
||||
# which propagates the incorrect type.
|
||||
self.api_key = api_key # type: ignore
|
||||
self.azure_ad_token = azure_ad_token
|
||||
self.generation_kwargs = generation_kwargs or {}
|
||||
self.system_prompt = system_prompt
|
||||
self.streaming_callback = streaming_callback
|
||||
@ -121,9 +127,8 @@ class AzureOpenAIGenerator(OpenAIGenerator):
|
||||
api_version=api_version,
|
||||
azure_endpoint=azure_endpoint,
|
||||
azure_deployment=azure_deployment,
|
||||
api_key=api_key,
|
||||
azure_ad_token=azure_ad_token,
|
||||
azure_ad_token_provider=azure_ad_token_provider,
|
||||
api_key=api_key.resolve_value() if api_key is not None else None,
|
||||
azure_ad_token=azure_ad_token.resolve_value() if azure_ad_token is not None else None,
|
||||
organization=organization,
|
||||
)
|
||||
|
||||
@ -142,6 +147,8 @@ class AzureOpenAIGenerator(OpenAIGenerator):
|
||||
streaming_callback=callback_name,
|
||||
generation_kwargs=self.generation_kwargs,
|
||||
system_prompt=self.system_prompt,
|
||||
api_key=self.api_key.to_dict() if self.api_key is not None else None,
|
||||
azure_ad_token=self.azure_ad_token.to_dict() if self.azure_ad_token is not None else None,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -151,6 +158,7 @@ class AzureOpenAIGenerator(OpenAIGenerator):
|
||||
:param data: The dictionary representation of this component.
|
||||
:return: The deserialized component instance.
|
||||
"""
|
||||
deserialize_secrets_inplace(data["init_parameters"], keys=["api_key", "azure_ad_token"])
|
||||
init_params = data.get("init_parameters", {})
|
||||
serialized_callback_handler = init_params.get("streaming_callback")
|
||||
if serialized_callback_handler:
|
||||
|
||||
@ -3,12 +3,13 @@ import os
|
||||
from typing import Optional, Callable, Dict, Any
|
||||
|
||||
# pylint: disable=import-error
|
||||
from openai.lib.azure import AzureADTokenProvider, AzureOpenAI
|
||||
from openai.lib.azure import AzureOpenAI
|
||||
|
||||
from haystack import default_to_dict, default_from_dict
|
||||
from haystack.components.generators.chat import OpenAIChatGenerator
|
||||
from haystack.components.generators.utils import serialize_callback_handler, deserialize_callback_handler
|
||||
from haystack.dataclasses import StreamingChunk
|
||||
from haystack.utils import Secret, deserialize_secrets_inplace
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -28,11 +29,12 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
|
||||
```python
|
||||
from haystack.components.generators.chat import AzureOpenAIChatGenerator
|
||||
from haystack.dataclasses import ChatMessage
|
||||
from haystack.utils import Secret
|
||||
|
||||
messages = [ChatMessage.from_user("What's Natural Language Processing?")]
|
||||
|
||||
client = AzureOpenAIChatGenerator(azure_endpoint="<Your Azure endpoint, e.g. `https://your-company.azure.openai.com/`>",
|
||||
api_key="<you api key>",
|
||||
api_key=Secret.from_token("<your-api-key>"),
|
||||
azure_deployment="<this is a model name, e.g. gpt-35-turbo>")
|
||||
response = client.run(messages)
|
||||
print(response)
|
||||
@ -63,9 +65,8 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
|
||||
azure_endpoint: Optional[str] = None,
|
||||
api_version: Optional[str] = "2023-05-15",
|
||||
azure_deployment: Optional[str] = "gpt-35-turbo",
|
||||
api_key: Optional[str] = None,
|
||||
azure_ad_token: Optional[str] = None,
|
||||
azure_ad_token_provider: Optional[AzureADTokenProvider] = None,
|
||||
api_key: Optional[Secret] = Secret.from_env_var("AZURE_OPENAI_API_KEY", strict=False),
|
||||
azure_ad_token: Optional[Secret] = Secret.from_env_var("AZURE_OPENAI_AD_TOKEN", strict=False),
|
||||
organization: Optional[str] = None,
|
||||
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
|
||||
generation_kwargs: Optional[Dict[str, Any]] = None,
|
||||
@ -76,8 +77,6 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
|
||||
:param azure_deployment: The deployment of the model, usually the model name.
|
||||
:param api_key: The API key to use for authentication.
|
||||
:param azure_ad_token: Azure Active Directory token, see https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id
|
||||
:param azure_ad_token_provider: A function that returns an Azure Active Directory token, will be invoked
|
||||
on every request.
|
||||
:param organization: The Organization ID, defaults to `None`. See
|
||||
[production best practices](https://platform.openai.com/docs/guides/production-best-practices/setting-up-your-organization).
|
||||
:param streaming_callback: A callback function that is called when a new token is received from the stream.
|
||||
@ -113,6 +112,13 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
|
||||
if not azure_endpoint:
|
||||
raise ValueError("Please provide an Azure endpoint or set the environment variable AZURE_OPENAI_ENDPOINT.")
|
||||
|
||||
if api_key is None and azure_ad_token is None:
|
||||
raise ValueError("Please provide an API key or an Azure Active Directory token.")
|
||||
|
||||
# The check above makes mypy incorrectly infer that api_key is never None,
|
||||
# which propagates the incorrect type.
|
||||
self.api_key = api_key # type: ignore
|
||||
self.azure_ad_token = azure_ad_token
|
||||
self.generation_kwargs = generation_kwargs or {}
|
||||
self.streaming_callback = streaming_callback
|
||||
self.api_version = api_version
|
||||
@ -125,9 +131,8 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
|
||||
api_version=api_version,
|
||||
azure_endpoint=azure_endpoint,
|
||||
azure_deployment=azure_deployment,
|
||||
api_key=api_key,
|
||||
azure_ad_token=azure_ad_token,
|
||||
azure_ad_token_provider=azure_ad_token_provider,
|
||||
api_key=api_key.resolve_value() if api_key is not None else None,
|
||||
azure_ad_token=azure_ad_token.resolve_value() if azure_ad_token is not None else None,
|
||||
organization=organization,
|
||||
)
|
||||
|
||||
@ -145,6 +150,8 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
|
||||
api_version=self.api_version,
|
||||
streaming_callback=callback_name,
|
||||
generation_kwargs=self.generation_kwargs,
|
||||
api_key=self.api_key.to_dict() if self.api_key is not None else None,
|
||||
azure_ad_token=self.azure_ad_token.to_dict() if self.azure_ad_token is not None else None,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -154,6 +161,7 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
|
||||
:param data: The dictionary representation of this component.
|
||||
:return: The deserialized component instance.
|
||||
"""
|
||||
deserialize_secrets_inplace(data["init_parameters"], keys=["api_key", "azure_ad_token"])
|
||||
init_params = data.get("init_parameters", {})
|
||||
serialized_callback_handler = init_params.get("streaming_callback")
|
||||
if serialized_callback_handler:
|
||||
|
||||
@ -9,6 +9,7 @@ from haystack.components.generators.utils import serialize_callback_handler, des
|
||||
from haystack.dataclasses import ChatMessage, StreamingChunk
|
||||
from haystack.lazy_imports import LazyImport
|
||||
from haystack.utils import ComponentDevice
|
||||
from haystack.utils import Secret, deserialize_secrets_inplace
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -57,7 +58,7 @@ class HuggingFaceLocalChatGenerator:
|
||||
model: str = "HuggingFaceH4/zephyr-7b-beta",
|
||||
task: Optional[Literal["text-generation", "text2text-generation"]] = None,
|
||||
device: Optional[ComponentDevice] = None,
|
||||
token: Optional[Union[str, bool]] = None,
|
||||
token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
|
||||
chat_template: Optional[str] = None,
|
||||
generation_kwargs: Optional[Dict[str, Any]] = None,
|
||||
huggingface_pipeline_kwargs: Optional[Dict[str, Any]] = None,
|
||||
@ -80,7 +81,6 @@ class HuggingFaceLocalChatGenerator:
|
||||
:param device: The device on which the model is loaded. If `None`, the default device is automatically
|
||||
selected. If a device/device map is specified in `huggingface_pipeline_kwargs`, it overrides this parameter.
|
||||
:param token: The token to use as HTTP bearer authorization for remote files.
|
||||
If True, will use the token generated when running huggingface-cli login (stored in ~/.huggingface).
|
||||
If the token is also specified in the `huggingface_pipeline_kwargs`, this parameter will be ignored.
|
||||
:param chat_template: This optional parameter allows you to specify a Jinja template for formatting chat
|
||||
messages. While high-quality and well-supported chat models typically include their own chat templates
|
||||
@ -113,6 +113,9 @@ class HuggingFaceLocalChatGenerator:
|
||||
huggingface_pipeline_kwargs = huggingface_pipeline_kwargs or {}
|
||||
generation_kwargs = generation_kwargs or {}
|
||||
|
||||
self.token = token
|
||||
token = token.resolve_value() if token else None
|
||||
|
||||
# check if the huggingface_pipeline_kwargs contain the essential parameters
|
||||
# otherwise, populate them with values from other init parameters
|
||||
huggingface_pipeline_kwargs.setdefault("model", model)
|
||||
@ -178,12 +181,11 @@ class HuggingFaceLocalChatGenerator:
|
||||
huggingface_pipeline_kwargs=self.huggingface_pipeline_kwargs,
|
||||
generation_kwargs=self.generation_kwargs,
|
||||
streaming_callback=callback_name,
|
||||
token=self.token.to_dict() if self.token else None,
|
||||
)
|
||||
|
||||
huggingface_pipeline_kwargs = serialization_dict["init_parameters"]["huggingface_pipeline_kwargs"]
|
||||
# we don't want to serialize valid tokens
|
||||
if isinstance(huggingface_pipeline_kwargs["token"], str):
|
||||
serialization_dict["init_parameters"]["huggingface_pipeline_kwargs"].pop("token")
|
||||
huggingface_pipeline_kwargs.pop("token", None)
|
||||
|
||||
serialize_hf_model_kwargs(huggingface_pipeline_kwargs)
|
||||
return serialization_dict
|
||||
@ -194,7 +196,7 @@ class HuggingFaceLocalChatGenerator:
|
||||
Deserialize this component from a dictionary.
|
||||
"""
|
||||
torch_and_transformers_import.check() # leave this, cls method
|
||||
|
||||
deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
|
||||
init_params = data.get("init_parameters", {})
|
||||
serialized_callback_handler = init_params.get("streaming_callback")
|
||||
if serialized_callback_handler:
|
||||
|
||||
@ -8,6 +8,7 @@ from haystack.components.generators.utils import serialize_callback_handler, des
|
||||
from haystack.dataclasses import ChatMessage, StreamingChunk
|
||||
from haystack.components.generators.hf_utils import check_valid_model, check_generation_params
|
||||
from haystack.lazy_imports import LazyImport
|
||||
from haystack.utils import Secret, deserialize_secrets_inplace
|
||||
|
||||
with LazyImport(message="Run 'pip install transformers'") as transformers_import:
|
||||
from huggingface_hub import InferenceClient
|
||||
@ -28,12 +29,13 @@ class HuggingFaceTGIChatGenerator:
|
||||
```python
|
||||
from haystack.components.generators.chat import HuggingFaceTGIChatGenerator
|
||||
from haystack.dataclasses import ChatMessage
|
||||
from haystack.utils import Secret
|
||||
|
||||
messages = [ChatMessage.from_system("\\nYou are a helpful, respectful and honest assistant"),
|
||||
ChatMessage.from_user("What's Natural Language Processing?")]
|
||||
|
||||
|
||||
client = HuggingFaceTGIChatGenerator(model="meta-llama/Llama-2-70b-chat-hf", token="<your-token>")
|
||||
client = HuggingFaceTGIChatGenerator(model="meta-llama/Llama-2-70b-chat-hf", token=Secret.from_token("<your-api-key>"))
|
||||
client.warm_up()
|
||||
response = client.run(messages, generation_kwargs={"max_new_tokens": 120})
|
||||
print(response)
|
||||
@ -51,7 +53,7 @@ class HuggingFaceTGIChatGenerator:
|
||||
|
||||
client = HuggingFaceTGIChatGenerator(model="meta-llama/Llama-2-70b-chat-hf",
|
||||
url="<your-tgi-endpoint-url>",
|
||||
token="<your-token>")
|
||||
token=Secret.from_token("<your-api-key>"))
|
||||
client.warm_up()
|
||||
response = client.run(messages, generation_kwargs={"max_new_tokens": 120})
|
||||
print(response)
|
||||
@ -85,7 +87,7 @@ class HuggingFaceTGIChatGenerator:
|
||||
self,
|
||||
model: str = "meta-llama/Llama-2-13b-chat-hf",
|
||||
url: Optional[str] = None,
|
||||
token: Optional[str] = None,
|
||||
token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
|
||||
chat_template: Optional[str] = None,
|
||||
generation_kwargs: Optional[Dict[str, Any]] = None,
|
||||
stop_words: Optional[List[str]] = None,
|
||||
@ -131,7 +133,7 @@ class HuggingFaceTGIChatGenerator:
|
||||
self.chat_template = chat_template
|
||||
self.token = token
|
||||
self.generation_kwargs = generation_kwargs
|
||||
self.client = InferenceClient(url or model, token=token)
|
||||
self.client = InferenceClient(url or model, token=token.resolve_value() if token else None)
|
||||
self.streaming_callback = streaming_callback
|
||||
self.tokenizer = None
|
||||
|
||||
@ -139,7 +141,9 @@ class HuggingFaceTGIChatGenerator:
|
||||
"""
|
||||
Load the tokenizer.
|
||||
"""
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(self.model, token=self.token)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.model, token=self.token.resolve_value() if self.token else None
|
||||
)
|
||||
# mypy can't infer that chat_template attribute exists on the object returned by AutoTokenizer.from_pretrained
|
||||
chat_template = getattr(self.tokenizer, "chat_template", None)
|
||||
if not chat_template and not self.chat_template:
|
||||
@ -162,7 +166,7 @@ class HuggingFaceTGIChatGenerator:
|
||||
model=self.model,
|
||||
url=self.url,
|
||||
chat_template=self.chat_template,
|
||||
token=self.token if not isinstance(self.token, str) else None, # don't serialize valid tokens
|
||||
token=self.token.to_dict() if self.token else None,
|
||||
generation_kwargs=self.generation_kwargs,
|
||||
streaming_callback=callback_name,
|
||||
)
|
||||
@ -172,6 +176,7 @@ class HuggingFaceTGIChatGenerator:
|
||||
"""
|
||||
Deserialize this component from a dictionary.
|
||||
"""
|
||||
deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
|
||||
init_params = data.get("init_parameters", {})
|
||||
serialized_callback_handler = init_params.get("streaming_callback")
|
||||
if serialized_callback_handler:
|
||||
|
||||
@ -13,6 +13,7 @@ from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice
|
||||
from haystack import component, default_from_dict, default_to_dict
|
||||
from haystack.components.generators.utils import serialize_callback_handler, deserialize_callback_handler
|
||||
from haystack.dataclasses import StreamingChunk, ChatMessage
|
||||
from haystack.utils import Secret, deserialize_secrets_inplace
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -62,7 +63,7 @@ class OpenAIChatGenerator:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: Optional[str] = None,
|
||||
api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
|
||||
model: str = "gpt-3.5-turbo",
|
||||
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
|
||||
api_base_url: Optional[str] = None,
|
||||
@ -73,8 +74,7 @@ class OpenAIChatGenerator:
|
||||
Creates an instance of OpenAIChatGenerator. Unless specified otherwise in the `model`, this is for OpenAI's
|
||||
GPT-3.5 model.
|
||||
|
||||
:param api_key: The OpenAI API key. It can be explicitly provided or automatically read from the
|
||||
environment variable OPENAI_API_KEY (recommended).
|
||||
:param api_key: The OpenAI API key.
|
||||
:param model: The name of the model to use.
|
||||
:param streaming_callback: A callback function that is called when a new token is received from the stream.
|
||||
The callback function accepts StreamingChunk as an argument.
|
||||
@ -101,12 +101,13 @@ class OpenAIChatGenerator:
|
||||
- `logit_bias`: Add a logit bias to specific tokens. The keys of the dictionary are tokens, and the
|
||||
values are the bias to add to that token.
|
||||
"""
|
||||
self.api_key = api_key
|
||||
self.model = model
|
||||
self.generation_kwargs = generation_kwargs or {}
|
||||
self.streaming_callback = streaming_callback
|
||||
self.api_base_url = api_base_url
|
||||
self.organization = organization
|
||||
self.client = OpenAI(api_key=api_key, organization=organization, base_url=api_base_url)
|
||||
self.client = OpenAI(api_key=api_key.resolve_value(), organization=organization, base_url=api_base_url)
|
||||
|
||||
def _get_telemetry_data(self) -> Dict[str, Any]:
|
||||
"""
|
||||
@ -127,6 +128,7 @@ class OpenAIChatGenerator:
|
||||
api_base_url=self.api_base_url,
|
||||
organization=self.organization,
|
||||
generation_kwargs=self.generation_kwargs,
|
||||
api_key=self.api_key.to_dict(),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -136,6 +138,7 @@ class OpenAIChatGenerator:
|
||||
:param data: The dictionary representation of this component.
|
||||
:return: The deserialized component instance.
|
||||
"""
|
||||
deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"])
|
||||
init_params = data.get("init_parameters", {})
|
||||
serialized_callback_handler = init_params.get("streaming_callback")
|
||||
if serialized_callback_handler:
|
||||
@ -334,7 +337,7 @@ class OpenAIChatGenerator:
|
||||
class GPTChatGenerator(OpenAIChatGenerator):
|
||||
def __init__(
|
||||
self,
|
||||
api_key: Optional[str] = None,
|
||||
api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
|
||||
model: str = "gpt-3.5-turbo",
|
||||
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
|
||||
api_base_url: Optional[str] = None,
|
||||
|
||||
@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional, Union, Callable
|
||||
|
||||
from haystack.dataclasses import StreamingChunk
|
||||
from haystack.lazy_imports import LazyImport
|
||||
from haystack.utils import Secret
|
||||
|
||||
with LazyImport(message="Run 'pip install transformers'") as transformers_import:
|
||||
from huggingface_hub import InferenceClient, HfApi
|
||||
@ -36,20 +37,20 @@ def check_generation_params(kwargs: Optional[Dict[str, Any]], additional_accepte
|
||||
)
|
||||
|
||||
|
||||
def check_valid_model(model_id: str, token: Optional[str]) -> None:
|
||||
def check_valid_model(model_id: str, token: Optional[Secret]) -> None:
|
||||
"""
|
||||
Check if the provided model ID corresponds to a valid model on HuggingFace Hub.
|
||||
Also check if the model is a text generation model.
|
||||
|
||||
:param model_id: A string representing the HuggingFace model ID.
|
||||
:param token: An optional string representing the authentication token.
|
||||
:param token: An optional authentication token.
|
||||
:raises ValueError: If the model is not found or is not a text generation model.
|
||||
"""
|
||||
transformers_import.check()
|
||||
|
||||
api = HfApi()
|
||||
try:
|
||||
model_info = api.model_info(model_id, token=token)
|
||||
model_info = api.model_info(model_id, token=token.resolve_value() if token else None)
|
||||
except RepositoryNotFoundError as e:
|
||||
raise ValueError(
|
||||
f"Model {model_id} not found on HuggingFace Hub. Please provide a valid HuggingFace model_id."
|
||||
|
||||
@ -1,11 +1,12 @@
|
||||
import logging
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from haystack import component, default_from_dict, default_to_dict
|
||||
|
||||
from haystack.lazy_imports import LazyImport
|
||||
from haystack.utils import ComponentDevice
|
||||
from haystack.utils.hf import deserialize_hf_model_kwargs, serialize_hf_model_kwargs
|
||||
from haystack.utils import Secret, deserialize_secrets_inplace
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -44,7 +45,7 @@ class HuggingFaceLocalGenerator:
|
||||
model: str = "google/flan-t5-base",
|
||||
task: Optional[Literal["text-generation", "text2text-generation"]] = None,
|
||||
device: Optional[ComponentDevice] = None,
|
||||
token: Optional[Union[str, bool]] = None,
|
||||
token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
|
||||
generation_kwargs: Optional[Dict[str, Any]] = None,
|
||||
huggingface_pipeline_kwargs: Optional[Dict[str, Any]] = None,
|
||||
stop_words: Optional[List[str]] = None,
|
||||
@ -63,7 +64,6 @@ class HuggingFaceLocalGenerator:
|
||||
:param device: The device on which the model is loaded. If `None`, the default device is automatically
|
||||
selected. If a device/device map is specified in `huggingface_pipeline_kwargs`, it overrides this parameter.
|
||||
:param token: The token to use as HTTP bearer authorization for remote files.
|
||||
If True, will use the token generated when running huggingface-cli login (stored in ~/.huggingface).
|
||||
If the token is also specified in the `huggingface_pipeline_kwargs`, this parameter will be ignored.
|
||||
:param generation_kwargs: A dictionary containing keyword arguments to customize text generation.
|
||||
Some examples: `max_length`, `max_new_tokens`, `temperature`, `top_k`, `top_p`,...
|
||||
@ -89,6 +89,9 @@ class HuggingFaceLocalGenerator:
|
||||
huggingface_pipeline_kwargs = huggingface_pipeline_kwargs or {}
|
||||
generation_kwargs = generation_kwargs or {}
|
||||
|
||||
self.token = token
|
||||
token = token.resolve_value() if token else None
|
||||
|
||||
# check if the huggingface_pipeline_kwargs contain the essential parameters
|
||||
# otherwise, populate them with values from other init parameters
|
||||
huggingface_pipeline_kwargs.setdefault("model", model)
|
||||
@ -156,12 +159,11 @@ class HuggingFaceLocalGenerator:
|
||||
huggingface_pipeline_kwargs=self.huggingface_pipeline_kwargs,
|
||||
generation_kwargs=self.generation_kwargs,
|
||||
stop_words=self.stop_words,
|
||||
token=self.token.to_dict() if self.token else None,
|
||||
)
|
||||
|
||||
huggingface_pipeline_kwargs = serialization_dict["init_parameters"]["huggingface_pipeline_kwargs"]
|
||||
# we don't want to serialize valid tokens
|
||||
if isinstance(huggingface_pipeline_kwargs["token"], str):
|
||||
serialization_dict["init_parameters"]["huggingface_pipeline_kwargs"].pop("token")
|
||||
huggingface_pipeline_kwargs.pop("token", None)
|
||||
|
||||
serialize_hf_model_kwargs(huggingface_pipeline_kwargs)
|
||||
return serialization_dict
|
||||
@ -171,6 +173,7 @@ class HuggingFaceLocalGenerator:
|
||||
"""
|
||||
Deserialize this component from a dictionary.
|
||||
"""
|
||||
deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
|
||||
deserialize_hf_model_kwargs(data["init_parameters"]["huggingface_pipeline_kwargs"])
|
||||
return default_from_dict(cls, data)
|
||||
|
||||
|
||||
@ -8,6 +8,7 @@ from haystack.components.generators.utils import serialize_callback_handler, des
|
||||
from haystack.dataclasses import StreamingChunk
|
||||
from haystack.components.generators.hf_utils import check_generation_params, check_valid_model
|
||||
from haystack.lazy_imports import LazyImport
|
||||
from haystack.utils import Secret, deserialize_secrets_inplace
|
||||
|
||||
with LazyImport(message="Run 'pip install transformers'") as transformers_import:
|
||||
from huggingface_hub import InferenceClient
|
||||
@ -29,7 +30,9 @@ class HuggingFaceTGIGenerator:
|
||||
|
||||
```python
|
||||
from haystack.components.generators import HuggingFaceTGIGenerator
|
||||
client = HuggingFaceTGIGenerator(model="mistralai/Mistral-7B-v0.1", token="<your-token>")
|
||||
from haystack.utils import Secret
|
||||
|
||||
client = HuggingFaceTGIGenerator(model="mistralai/Mistral-7B-v0.1", token=Secret.from_token("<your-api-key>"))
|
||||
client.warm_up()
|
||||
response = client.run("What's Natural Language Processing?", max_new_tokens=120)
|
||||
print(response)
|
||||
@ -42,7 +45,7 @@ class HuggingFaceTGIGenerator:
|
||||
from haystack.components.generators import HuggingFaceTGIGenerator
|
||||
client = HuggingFaceTGIGenerator(model="mistralai/Mistral-7B-v0.1",
|
||||
url="<your-tgi-endpoint-url>",
|
||||
token="<your-token>")
|
||||
token=Secret.from_token("<your-api-key>"))
|
||||
client.warm_up()
|
||||
response = client.run("What's Natural Language Processing?")
|
||||
print(response)
|
||||
@ -74,7 +77,7 @@ class HuggingFaceTGIGenerator:
|
||||
self,
|
||||
model: str = "mistralai/Mistral-7B-v0.1",
|
||||
url: Optional[str] = None,
|
||||
token: Optional[str] = None,
|
||||
token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
|
||||
generation_kwargs: Optional[Dict[str, Any]] = None,
|
||||
stop_words: Optional[List[str]] = None,
|
||||
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
|
||||
@ -113,7 +116,7 @@ class HuggingFaceTGIGenerator:
|
||||
self.url = url
|
||||
self.token = token
|
||||
self.generation_kwargs = generation_kwargs
|
||||
self.client = InferenceClient(url or model, token=token)
|
||||
self.client = InferenceClient(url or model, token=token.resolve_value() if token else None)
|
||||
self.streaming_callback = streaming_callback
|
||||
self.tokenizer = None
|
||||
|
||||
@ -121,7 +124,9 @@ class HuggingFaceTGIGenerator:
|
||||
"""
|
||||
Load the tokenizer
|
||||
"""
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(self.model, token=self.token)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.model, token=self.token.resolve_value() if self.token else None
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
@ -134,7 +139,7 @@ class HuggingFaceTGIGenerator:
|
||||
self,
|
||||
model=self.model,
|
||||
url=self.url,
|
||||
token=self.token if not isinstance(self.token, str) else None, # don't serialize valid tokens
|
||||
token=self.token.to_dict() if self.token else None,
|
||||
generation_kwargs=self.generation_kwargs,
|
||||
streaming_callback=callback_name,
|
||||
)
|
||||
@ -144,6 +149,7 @@ class HuggingFaceTGIGenerator:
|
||||
"""
|
||||
Deserialize this component from a dictionary.
|
||||
"""
|
||||
deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
|
||||
init_params = data.get("init_parameters", {})
|
||||
serialized_callback_handler = init_params.get("streaming_callback")
|
||||
if serialized_callback_handler:
|
||||
|
||||
@ -9,6 +9,7 @@ from openai.types.chat import ChatCompletionChunk, ChatCompletion
|
||||
from haystack import component, default_from_dict, default_to_dict
|
||||
from haystack.components.generators.utils import serialize_callback_handler, deserialize_callback_handler
|
||||
from haystack.dataclasses import StreamingChunk, ChatMessage
|
||||
from haystack.utils import Secret, deserialize_secrets_inplace
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -50,7 +51,7 @@ class OpenAIGenerator:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: Optional[str] = None,
|
||||
api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
|
||||
model: str = "gpt-3.5-turbo",
|
||||
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
|
||||
api_base_url: Optional[str] = None,
|
||||
@ -62,8 +63,7 @@ class OpenAIGenerator:
|
||||
Creates an instance of OpenAIGenerator. Unless specified otherwise in the `model`, this is for OpenAI's
|
||||
GPT-3.5 model.
|
||||
|
||||
:param api_key: The OpenAI API key. It can be explicitly provided or automatically read from the
|
||||
environment variable OPENAI_API_KEY (recommended).
|
||||
:param api_key: The OpenAI API key.
|
||||
:param model: The name of the model to use.
|
||||
:param streaming_callback: A callback function that is called when a new token is received from the stream.
|
||||
The callback function accepts StreamingChunk as an argument.
|
||||
@ -92,6 +92,7 @@ class OpenAIGenerator:
|
||||
- `logit_bias`: Add a logit bias to specific tokens. The keys of the dictionary are tokens, and the
|
||||
values are the bias to add to that token.
|
||||
"""
|
||||
self.api_key = api_key
|
||||
self.model = model
|
||||
self.generation_kwargs = generation_kwargs or {}
|
||||
self.system_prompt = system_prompt
|
||||
@ -99,7 +100,7 @@ class OpenAIGenerator:
|
||||
|
||||
self.api_base_url = api_base_url
|
||||
self.organization = organization
|
||||
self.client = OpenAI(api_key=api_key, organization=organization, base_url=api_base_url)
|
||||
self.client = OpenAI(api_key=api_key.resolve_value(), organization=organization, base_url=api_base_url)
|
||||
|
||||
def _get_telemetry_data(self) -> Dict[str, Any]:
|
||||
"""
|
||||
@ -120,6 +121,7 @@ class OpenAIGenerator:
|
||||
api_base_url=self.api_base_url,
|
||||
generation_kwargs=self.generation_kwargs,
|
||||
system_prompt=self.system_prompt,
|
||||
api_key=self.api_key.to_dict(),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -129,6 +131,7 @@ class OpenAIGenerator:
|
||||
:param data: The dictionary representation of this component.
|
||||
:return: The deserialized component instance.
|
||||
"""
|
||||
deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"])
|
||||
init_params = data.get("init_parameters", {})
|
||||
serialized_callback_handler = init_params.get("streaming_callback")
|
||||
if serialized_callback_handler:
|
||||
@ -279,7 +282,7 @@ class OpenAIGenerator:
|
||||
class GPTGenerator(OpenAIGenerator):
|
||||
def __init__(
|
||||
self,
|
||||
api_key: Optional[str] = None,
|
||||
api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
|
||||
model: str = "gpt-3.5-turbo",
|
||||
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
|
||||
api_base_url: Optional[str] = None,
|
||||
|
||||
@ -6,6 +6,7 @@ from haystack import ComponentError, Document, component, default_from_dict, def
|
||||
from haystack.lazy_imports import LazyImport
|
||||
from haystack.utils import ComponentDevice, DeviceMap
|
||||
from haystack.utils.hf import deserialize_hf_model_kwargs, serialize_hf_model_kwargs, resolve_hf_device_map
|
||||
from haystack.utils import Secret, deserialize_secrets_inplace
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -40,7 +41,7 @@ class TransformersSimilarityRanker:
|
||||
self,
|
||||
model: Union[str, Path] = "cross-encoder/ms-marco-MiniLM-L-6-v2",
|
||||
device: Optional[ComponentDevice] = None,
|
||||
token: Union[bool, str, None] = None,
|
||||
token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
|
||||
top_k: int = 10,
|
||||
query_prefix: str = "",
|
||||
document_prefix: str = "",
|
||||
@ -119,9 +120,11 @@ class TransformersSimilarityRanker:
|
||||
"""
|
||||
if self.model is None:
|
||||
self.model = AutoModelForSequenceClassification.from_pretrained(
|
||||
self.model_name_or_path, token=self.token, **self.model_kwargs
|
||||
self.model_name_or_path, token=self.token.resolve_value() if self.token else None, **self.model_kwargs
|
||||
)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.model_name_or_path, token=self.token.resolve_value() if self.token else None
|
||||
)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, token=self.token)
|
||||
self.device = ComponentDevice.from_multiple(device_map=DeviceMap.from_hf(self.model.hf_device_map))
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
@ -132,7 +135,7 @@ class TransformersSimilarityRanker:
|
||||
self,
|
||||
device=None,
|
||||
model=self.model_name_or_path,
|
||||
token=self.token if not isinstance(self.token, str) else None, # don't serialize valid tokens
|
||||
token=self.token.to_dict() if self.token else None,
|
||||
top_k=self.top_k,
|
||||
query_prefix=self.query_prefix,
|
||||
document_prefix=self.document_prefix,
|
||||
@ -152,6 +155,7 @@ class TransformersSimilarityRanker:
|
||||
"""
|
||||
Deserialize this component from a dictionary.
|
||||
"""
|
||||
deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
|
||||
init_params = data["init_parameters"]
|
||||
if init_params["device"] is not None:
|
||||
init_params["device"] = ComponentDevice.from_dict(init_params["device"])
|
||||
|
||||
@ -1,11 +1,11 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Any
|
||||
import os
|
||||
|
||||
import requests
|
||||
|
||||
from haystack import Document, component, default_to_dict, ComponentError
|
||||
from haystack import Document, component, default_to_dict, ComponentError, default_from_dict
|
||||
from haystack.utils import Secret, deserialize_secrets_inplace
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -27,43 +27,48 @@ class SearchApiWebSearch:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: Optional[str] = None,
|
||||
api_key: Secret = Secret.from_env_var("SEARCHAPI_API_KEY"),
|
||||
top_k: Optional[int] = 10,
|
||||
allowed_domains: Optional[List[str]] = None,
|
||||
search_params: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
"""
|
||||
:param api_key: API key for the SearchApi API. It can be
|
||||
explicitly provided or automatically read from the
|
||||
environment variable SEARCHAPI_API_KEY (recommended).
|
||||
:param api_key: API key for the SearchApi API
|
||||
:param top_k: Number of documents to return.
|
||||
:param allowed_domains: List of domains to limit the search to.
|
||||
:param search_params: Additional parameters passed to the SearchApi API.
|
||||
For example, you can set 'num' to 100 to increase the number of search results.
|
||||
See the [SearchApi website](https://www.searchapi.io/) for more details.
|
||||
"""
|
||||
api_key = api_key or os.environ.get("SEARCHAPI_API_KEY")
|
||||
# we check whether api_key is None or an empty string
|
||||
if not api_key:
|
||||
msg = (
|
||||
"SearchApiWebSearch expects an API key. "
|
||||
"Set the SEARCHAPI_API_KEY environment variable (recommended) or pass it explicitly."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
self.api_key = api_key
|
||||
self.top_k = top_k
|
||||
self.allowed_domains = allowed_domains
|
||||
self.search_params = search_params or {}
|
||||
|
||||
# Ensure that the API key is resolved.
|
||||
_ = self.api_key.resolve_value()
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Serialize this component to a dictionary.
|
||||
"""
|
||||
return default_to_dict(
|
||||
self, top_k=self.top_k, allowed_domains=self.allowed_domains, search_params=self.search_params
|
||||
self,
|
||||
top_k=self.top_k,
|
||||
allowed_domains=self.allowed_domains,
|
||||
search_params=self.search_params,
|
||||
api_key=self.api_key.to_dict(),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "SearchApiWebSearch":
|
||||
"""
|
||||
Deserialize this component from a dictionary.
|
||||
"""
|
||||
deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"])
|
||||
return default_from_dict(cls, data)
|
||||
|
||||
@component.output_types(documents=List[Document], links=List[str])
|
||||
def run(self, query: str):
|
||||
"""
|
||||
@ -74,7 +79,7 @@ class SearchApiWebSearch:
|
||||
query_prepend = "OR ".join(f"site:{domain} " for domain in self.allowed_domains) if self.allowed_domains else ""
|
||||
|
||||
payload = json.dumps({"q": query_prepend + " " + query, **self.search_params})
|
||||
headers = {"Authorization": f"Bearer {self.api_key}", "X-SearchApi-Source": "Haystack"}
|
||||
headers = {"Authorization": f"Bearer {self.api_key.resolve_value()}", "X-SearchApi-Source": "Haystack"}
|
||||
|
||||
try:
|
||||
response = requests.get(SEARCHAPI_BASE_URL, headers=headers, params=payload, timeout=90)
|
||||
|
||||
@ -1,11 +1,11 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Any
|
||||
import os
|
||||
|
||||
import requests
|
||||
|
||||
from haystack import Document, component, default_to_dict, ComponentError
|
||||
from haystack import Document, component, default_to_dict, ComponentError, default_from_dict
|
||||
from haystack.utils import Secret, deserialize_secrets_inplace
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -27,43 +27,47 @@ class SerperDevWebSearch:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: Optional[str] = None,
|
||||
api_key: Secret = Secret.from_env_var("SERPERDEV_API_KEY"),
|
||||
top_k: Optional[int] = 10,
|
||||
allowed_domains: Optional[List[str]] = None,
|
||||
search_params: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
"""
|
||||
:param api_key: API key for the SerperDev API. It can be
|
||||
explicitly provided or automatically read from the
|
||||
environment variable SERPERDEV_API_KEY (recommended).
|
||||
:param api_key: API key for the SerperDev API.
|
||||
:param top_k: Number of documents to return.
|
||||
:param allowed_domains: List of domains to limit the search to.
|
||||
:param search_params: Additional parameters passed to the SerperDev API.
|
||||
For example, you can set 'num' to 20 to increase the number of search results.
|
||||
See the [Serper Dev website](https://serper.dev/) for more details.
|
||||
"""
|
||||
api_key = api_key or os.environ.get("SERPERDEV_API_KEY")
|
||||
# we check whether api_key is None or an empty string
|
||||
if not api_key:
|
||||
msg = (
|
||||
"SerperDevWebSearch expects an API key. "
|
||||
"Set the SERPERDEV_API_KEY environment variable (recommended) or pass it explicitly."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
self.api_key = api_key
|
||||
self.top_k = top_k
|
||||
self.allowed_domains = allowed_domains
|
||||
self.search_params = search_params or {}
|
||||
|
||||
# Ensure that the API key is resolved.
|
||||
_ = self.api_key.resolve_value()
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Serialize this component to a dictionary.
|
||||
"""
|
||||
return default_to_dict(
|
||||
self, top_k=self.top_k, allowed_domains=self.allowed_domains, search_params=self.search_params
|
||||
self,
|
||||
top_k=self.top_k,
|
||||
allowed_domains=self.allowed_domains,
|
||||
search_params=self.search_params,
|
||||
api_key=self.api_key.to_dict(),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "SerperDevWebSearch":
|
||||
"""
|
||||
Deserialize this component from a dictionary.
|
||||
"""
|
||||
deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"])
|
||||
return default_from_dict(cls, data)
|
||||
|
||||
@component.output_types(documents=List[Document], links=List[str])
|
||||
def run(self, query: str):
|
||||
"""
|
||||
@ -76,10 +80,10 @@ class SerperDevWebSearch:
|
||||
payload = json.dumps(
|
||||
{"q": query_prepend + query, "gl": "us", "hl": "en", "autocorrect": True, **self.search_params}
|
||||
)
|
||||
headers = {"X-API-KEY": self.api_key, "Content-Type": "application/json"}
|
||||
headers = {"X-API-KEY": self.api_key.resolve_value(), "Content-Type": "application/json"}
|
||||
|
||||
try:
|
||||
response = requests.post(SERPERDEV_BASE_URL, headers=headers, data=payload, timeout=30)
|
||||
response = requests.post(SERPERDEV_BASE_URL, headers=headers, data=payload, timeout=30) # type: ignore
|
||||
response.raise_for_status() # Will raise an HTTPError for bad responses
|
||||
except requests.Timeout:
|
||||
raise TimeoutError(f"Request to {self.__class__.__name__} timed out.")
|
||||
|
||||
@ -2,4 +2,4 @@ from haystack.utils.expit import expit
|
||||
from haystack.utils.requests_utils import request_with_retry
|
||||
from haystack.utils.filters import document_matches_filter
|
||||
from haystack.utils.device import ComponentDevice, DeviceType, Device, DeviceMap
|
||||
from haystack.utils.auth import Secret
|
||||
from haystack.utils.auth import Secret, deserialize_secrets_inplace
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from enum import Enum
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Dict, Iterable, List, Optional, Union
|
||||
from dataclasses import dataclass
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
@ -193,3 +193,21 @@ class EnvVarSecret(Secret):
|
||||
if out is None and self._strict:
|
||||
raise ValueError(f"None of the following authentication environment variables are set: {self._env_vars}")
|
||||
return out
|
||||
|
||||
|
||||
def deserialize_secrets_inplace(data: Dict[str, Any], keys: Iterable[str], *, recursive: bool = False):
    """
    Deserialize secrets in a dictionary inplace.

    Every entry of `data` whose key is listed in `keys` and whose value is not `None`
    is replaced by the result of `Secret.from_dict(value)`. `None` values are left
    untouched so that optional secrets survive a round trip.

    :param data:
        The dictionary with the serialized data. It is modified in place.
    :param keys:
        The keys of the secrets to deserialize.
    :param recursive:
        Whether to recursively deserialize nested dictionaries. When enabled, nested
        dictionaries are searched at every depth for the given keys.
    """
    for k, v in data.items():
        # Match the key first: a serialized secret is itself a dictionary, so testing
        # for "is a dict" before the key match would descend into the secret's own
        # payload and never deserialize it.
        if k in keys and v is not None:
            # Re-assigning an existing key does not change the dict's size, so this is
            # safe while iterating over `data.items()`.
            data[k] = Secret.from_dict(v)
        elif recursive and isinstance(v, dict):
            # Propagate the flag, otherwise recursion would stop after one level.
            deserialize_secrets_inplace(v, keys, recursive=True)
|
||||
|
||||
@ -4,3 +4,37 @@ features:
|
||||
Expose a `Secret` type to provide consistent API for any component that requires secrets for authentication.
|
||||
Currently supports string tokens and environment variables. Token-based secrets are automatically
|
||||
prevented from being serialized to disk (to prevent accidental leakage of secrets).
|
||||
```python
|
||||
@component
|
||||
class MyComponent:
|
||||
def __init__(self, api_key: Optional[Secret] = None, **kwargs):
|
||||
self.api_key = api_key
|
||||
self.backend = None
|
||||
|
||||
def warm_up(self):
|
||||
# Call resolve_value to yield a single result. The semantics of the result is policy-dependent.
|
||||
# Currently, all supported policies will return a single string token.
|
||||
self.backend = SomeBackend(api_key=self.api_key.resolve_value() if self.api_key else None, ...)
|
||||
|
||||
def to_dict(self):
|
||||
# Serialize the policy like any other (custom) data. If the policy is token-based, it will
|
||||
# raise an error.
|
||||
return default_to_dict(self, api_key=self.api_key.to_dict() if self.api_key else None, ...)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data):
|
||||
# Deserialize the policy data before passing it to the generic from_dict function.
|
||||
api_key_data = data["init_parameters"]["api_key"]
|
||||
api_key = Secret.from_dict(api_key_data) if api_key_data is not None else None
|
||||
data["init_parameters"]["api_key"] = api_key
|
||||
return default_from_dict(cls, data)
|
||||
|
||||
# No authentication.
|
||||
component = MyComponent(api_key=None)
|
||||
# Token based authentication
|
||||
component = MyComponent(api_key=Secret.from_token("sk-randomAPIkeyasdsa32ekasd32e"))
|
||||
component.to_dict() # Error! Can't serialize authentication tokens
|
||||
# Environment variable based authentication
|
||||
component = MyComponent(api_key=Secret.from_env_var("OPENAI_API_KEY"))
|
||||
component.to_dict() # This is fine
|
||||
```
|
||||
|
||||
@ -0,0 +1,8 @@
|
||||
---
|
||||
upgrade:
|
||||
- |
|
||||
Update secret handling for components using the `Secret` type. The following components are affected:
|
||||
`RemoteWhisperTranscriber`, `AzureOCRDocumentConverter`, `AzureOpenAIDocumentEmbedder`, `AzureOpenAITextEmbedder`, `HuggingFaceTEIDocumentEmbedder`, `HuggingFaceTEITextEmbedder`, `OpenAIDocumentEmbedder`, `SentenceTransformersDocumentEmbedder`, `SentenceTransformersTextEmbedder`, `AzureOpenAIGenerator`, `AzureOpenAIChatGenerator`, `HuggingFaceLocalChatGenerator`, `HuggingFaceTGIChatGenerator`, `OpenAIChatGenerator`, `HuggingFaceLocalGenerator`, `HuggingFaceTGIGenerator`, `OpenAIGenerator`, `TransformersSimilarityRanker`, `SearchApiWebSearch`, `SerperDevWebSearch`
|
||||
|
||||
The default init parameters for `api_key`, `token`, `azure_ad_token` have been adjusted to use environment variables wherever possible. The `azure_ad_token_provider` parameter has been removed from Azure-based components. Components based on Hugging
|
||||
Face are now required to either use a token or an environment variable if authentication is required - The on-disk local token file is no longer supported.
|
||||
@ -1,29 +1,29 @@
|
||||
import os
|
||||
import pytest
|
||||
from openai import OpenAIError
|
||||
|
||||
from haystack.components.audio.whisper_remote import RemoteWhisperTranscriber
|
||||
from haystack.dataclasses import ByteStream
|
||||
from haystack.utils import Secret
|
||||
|
||||
|
||||
class TestRemoteWhisperTranscriber:
|
||||
def test_init_no_key(self, monkeypatch):
|
||||
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||
with pytest.raises(OpenAIError):
|
||||
RemoteWhisperTranscriber(api_key=None)
|
||||
with pytest.raises(ValueError, match="None of the .* environment variables are set"):
|
||||
RemoteWhisperTranscriber()
|
||||
|
||||
def test_init_key_env_var(self, monkeypatch):
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "test_api_key")
|
||||
t = RemoteWhisperTranscriber(api_key=None)
|
||||
t = RemoteWhisperTranscriber()
|
||||
assert t.client.api_key == "test_api_key"
|
||||
|
||||
def test_init_key_module_env_and_global_var(self, monkeypatch):
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "test_api_key_2")
|
||||
t = RemoteWhisperTranscriber(api_key=None)
|
||||
t = RemoteWhisperTranscriber()
|
||||
assert t.client.api_key == "test_api_key_2"
|
||||
|
||||
def test_init_default(self):
|
||||
transcriber = RemoteWhisperTranscriber(api_key="test_api_key")
|
||||
transcriber = RemoteWhisperTranscriber(api_key=Secret.from_token("test_api_key"))
|
||||
assert transcriber.client.api_key == "test_api_key"
|
||||
assert transcriber.model == "whisper-1"
|
||||
assert transcriber.organization is None
|
||||
@ -31,7 +31,7 @@ class TestRemoteWhisperTranscriber:
|
||||
|
||||
def test_init_custom_parameters(self):
|
||||
transcriber = RemoteWhisperTranscriber(
|
||||
api_key="test_api_key",
|
||||
api_key=Secret.from_token("test_api_key"),
|
||||
model="whisper-1",
|
||||
organization="test-org",
|
||||
api_base_url="test_api_url",
|
||||
@ -42,6 +42,7 @@ class TestRemoteWhisperTranscriber:
|
||||
)
|
||||
|
||||
assert transcriber.model == "whisper-1"
|
||||
assert transcriber.api_key == Secret.from_token("test_api_key")
|
||||
assert transcriber.organization == "test-org"
|
||||
assert transcriber.api_base_url == "test_api_url"
|
||||
assert transcriber.whisper_params == {
|
||||
@ -51,12 +52,14 @@ class TestRemoteWhisperTranscriber:
|
||||
"temperature": "0.5",
|
||||
}
|
||||
|
||||
def test_to_dict_default_parameters(self):
|
||||
transcriber = RemoteWhisperTranscriber(api_key="test_api_key")
|
||||
def test_to_dict_default_parameters(self, monkeypatch):
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "test_api_key")
|
||||
transcriber = RemoteWhisperTranscriber()
|
||||
data = transcriber.to_dict()
|
||||
assert data == {
|
||||
"type": "haystack.components.audio.whisper_remote.RemoteWhisperTranscriber",
|
||||
"init_parameters": {
|
||||
"api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
|
||||
"model": "whisper-1",
|
||||
"api_base_url": None,
|
||||
"organization": None,
|
||||
@ -64,9 +67,10 @@ class TestRemoteWhisperTranscriber:
|
||||
},
|
||||
}
|
||||
|
||||
def test_to_dict_with_custom_init_parameters(self):
|
||||
def test_to_dict_with_custom_init_parameters(self, monkeypatch):
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "test_api_key")
|
||||
transcriber = RemoteWhisperTranscriber(
|
||||
api_key="test_api_key",
|
||||
api_key=Secret.from_env_var("ENV_VAR", strict=False),
|
||||
model="whisper-1",
|
||||
organization="test-org",
|
||||
api_base_url="test_api_url",
|
||||
@ -79,6 +83,7 @@ class TestRemoteWhisperTranscriber:
|
||||
assert data == {
|
||||
"type": "haystack.components.audio.whisper_remote.RemoteWhisperTranscriber",
|
||||
"init_parameters": {
|
||||
"api_key": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
|
||||
"model": "whisper-1",
|
||||
"organization": "test-org",
|
||||
"api_base_url": "test_api_url",
|
||||
@ -142,6 +147,7 @@ class TestRemoteWhisperTranscriber:
|
||||
data = {
|
||||
"type": "haystack.components.audio.whisper_remote.RemoteWhisperTranscriber",
|
||||
"init_parameters": {
|
||||
"api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
|
||||
"model": "whisper-1",
|
||||
"api_base_url": "https://api.openai.com/v1",
|
||||
"organization": None,
|
||||
@ -149,7 +155,7 @@ class TestRemoteWhisperTranscriber:
|
||||
},
|
||||
}
|
||||
|
||||
with pytest.raises(OpenAIError):
|
||||
with pytest.raises(ValueError, match="None of the .* environment variables are set"):
|
||||
RemoteWhisperTranscriber.from_dict(data)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
|
||||
@ -5,6 +5,7 @@ import pytest
|
||||
|
||||
from haystack.components.converters.azure import AzureOCRDocumentConverter
|
||||
from haystack.dataclasses import ByteStream
|
||||
from haystack.utils import Secret
|
||||
|
||||
|
||||
class TestAzureOCRDocumentConverter:
|
||||
@ -13,15 +14,23 @@ class TestAzureOCRDocumentConverter:
|
||||
with pytest.raises(ValueError):
|
||||
AzureOCRDocumentConverter(endpoint="test_endpoint")
|
||||
|
||||
def test_to_dict(self):
|
||||
component = AzureOCRDocumentConverter(endpoint="test_endpoint", api_key="test_credential_key")
|
||||
@patch("haystack.utils.auth.EnvVarSecret.resolve_value")
|
||||
def test_to_dict(self, mock_resolve_value):
|
||||
mock_resolve_value.return_value = "test_api_key"
|
||||
component = AzureOCRDocumentConverter(endpoint="test_endpoint")
|
||||
data = component.to_dict()
|
||||
assert data == {
|
||||
"type": "haystack.components.converters.azure.AzureOCRDocumentConverter",
|
||||
"init_parameters": {"endpoint": "test_endpoint", "model_id": "prebuilt-read"},
|
||||
"init_parameters": {
|
||||
"api_key": {"env_vars": ["AZURE_AI_API_KEY"], "strict": True, "type": "env_var"},
|
||||
"endpoint": "test_endpoint",
|
||||
"model_id": "prebuilt-read",
|
||||
},
|
||||
}
|
||||
|
||||
def test_run(self, test_files_path):
|
||||
@patch("haystack.utils.auth.EnvVarSecret.resolve_value")
|
||||
def test_run(self, mock_resolve_value, test_files_path):
|
||||
mock_resolve_value.return_value = "test_api_key"
|
||||
with patch("haystack.components.converters.azure.DocumentAnalysisClient") as mock_azure_client:
|
||||
mock_result = Mock(pages=[Mock(lines=[Mock(content="mocked line 1"), Mock(content="mocked line 2")])])
|
||||
mock_result.to_dict.return_value = {
|
||||
@ -32,7 +41,7 @@ class TestAzureOCRDocumentConverter:
|
||||
}
|
||||
mock_azure_client.return_value.begin_analyze_document.return_value.result.return_value = mock_result
|
||||
|
||||
component = AzureOCRDocumentConverter(endpoint="test_endpoint", api_key="test_credential_key")
|
||||
component = AzureOCRDocumentConverter(endpoint="test_endpoint")
|
||||
output = component.run(sources=[test_files_path / "pdf" / "sample_pdf_1.pdf"])
|
||||
document = output["documents"][0]
|
||||
assert document.content == "mocked line 1\nmocked line 2\n\f"
|
||||
@ -44,11 +53,12 @@ class TestAzureOCRDocumentConverter:
|
||||
"pages": [{"lines": [{"content": "mocked line 1"}, {"content": "mocked line 2"}]}],
|
||||
}
|
||||
|
||||
def test_run_with_meta(self, test_files_path):
|
||||
@patch("haystack.utils.auth.EnvVarSecret.resolve_value")
|
||||
def test_run_with_meta(self, mock_resolve_value, test_files_path):
|
||||
mock_resolve_value.return_value = "test_api_key"
|
||||
bytestream = ByteStream(data=b"test", meta={"author": "test_author", "language": "en"})
|
||||
with patch("haystack.components.converters.azure.DocumentAnalysisClient"):
|
||||
component = AzureOCRDocumentConverter(endpoint="test_endpoint", api_key="test_credential_key")
|
||||
|
||||
component = AzureOCRDocumentConverter(endpoint="test_endpoint")
|
||||
output = component.run(
|
||||
sources=[bytestream, test_files_path / "pdf" / "sample_pdf_1.pdf"], meta={"language": "it"}
|
||||
)
|
||||
@ -63,7 +73,7 @@ class TestAzureOCRDocumentConverter:
|
||||
@pytest.mark.skipif(not os.environ.get("CORE_AZURE_CS_API_KEY", None), reason="Azure credentials not available")
|
||||
def test_run_with_pdf_file(self, test_files_path):
|
||||
component = AzureOCRDocumentConverter(
|
||||
endpoint=os.environ["CORE_AZURE_CS_ENDPOINT"], api_key=os.environ["CORE_AZURE_CS_API_KEY"]
|
||||
endpoint=os.environ["CORE_AZURE_CS_ENDPOINT"], api_key=Secret.from_env_var("CORE_AZURE_CS_API_KEY")
|
||||
)
|
||||
output = component.run(sources=[test_files_path / "pdf" / "sample_pdf_1.pdf"])
|
||||
documents = output["documents"]
|
||||
@ -77,7 +87,7 @@ class TestAzureOCRDocumentConverter:
|
||||
@pytest.mark.skipif(not os.environ.get("CORE_AZURE_CS_API_KEY", None), reason="Azure credentials not available")
|
||||
def test_with_image_file(self, test_files_path):
|
||||
component = AzureOCRDocumentConverter(
|
||||
endpoint=os.environ["CORE_AZURE_CS_ENDPOINT"], api_key=os.environ["CORE_AZURE_CS_API_KEY"]
|
||||
endpoint=os.environ["CORE_AZURE_CS_ENDPOINT"], api_key=Secret.from_env_var("CORE_AZURE_CS_API_KEY")
|
||||
)
|
||||
output = component.run(sources=[test_files_path / "images" / "haystack-logo.png"])
|
||||
documents = output["documents"]
|
||||
@ -90,7 +100,7 @@ class TestAzureOCRDocumentConverter:
|
||||
@pytest.mark.skipif(not os.environ.get("CORE_AZURE_CS_API_KEY", None), reason="Azure credentials not available")
|
||||
def test_run_with_docx_file(self, test_files_path):
|
||||
component = AzureOCRDocumentConverter(
|
||||
endpoint=os.environ["CORE_AZURE_CS_ENDPOINT"], api_key=os.environ["CORE_AZURE_CS_API_KEY"]
|
||||
endpoint=os.environ["CORE_AZURE_CS_ENDPOINT"], api_key=Secret.from_env_var("CORE_AZURE_CS_API_KEY")
|
||||
)
|
||||
output = component.run(sources=[test_files_path / "docx" / "sample_docx.docx"])
|
||||
documents = output["documents"]
|
||||
|
||||
@ -19,14 +19,15 @@ class TestAzureOpenAIDocumentEmbedder:
|
||||
assert embedder.meta_fields_to_embed == []
|
||||
assert embedder.embedding_separator == "\n"
|
||||
|
||||
def test_to_dict(self):
|
||||
component = AzureOpenAIDocumentEmbedder(
|
||||
api_key="fake-api-key", azure_endpoint="https://example-resource.azure.openai.com/"
|
||||
)
|
||||
def test_to_dict(self, monkeypatch):
|
||||
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key")
|
||||
component = AzureOpenAIDocumentEmbedder(azure_endpoint="https://example-resource.azure.openai.com/")
|
||||
data = component.to_dict()
|
||||
assert data == {
|
||||
"type": "haystack.components.embedders.azure_document_embedder.AzureOpenAIDocumentEmbedder",
|
||||
"init_parameters": {
|
||||
"api_key": {"env_vars": ["AZURE_OPENAI_API_KEY"], "strict": False, "type": "env_var"},
|
||||
"azure_ad_token": {"env_vars": ["AZURE_OPENAI_AD_TOKEN"], "strict": False, "type": "env_var"},
|
||||
"api_version": "2023-05-15",
|
||||
"azure_deployment": "text-embedding-ada-002",
|
||||
"azure_endpoint": "https://example-resource.azure.openai.com/",
|
||||
|
||||
@ -16,14 +16,15 @@ class TestAzureOpenAITextEmbedder:
|
||||
assert embedder.prefix == ""
|
||||
assert embedder.suffix == ""
|
||||
|
||||
def test_to_dict(self):
|
||||
component = AzureOpenAITextEmbedder(
|
||||
api_key="fake-api-key", azure_endpoint="https://example-resource.azure.openai.com/"
|
||||
)
|
||||
def test_to_dict(self, monkeypatch):
|
||||
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake-api-key")
|
||||
component = AzureOpenAITextEmbedder(azure_endpoint="https://example-resource.azure.openai.com/")
|
||||
data = component.to_dict()
|
||||
assert data == {
|
||||
"type": "haystack.components.embedders.azure_text_embedder.AzureOpenAITextEmbedder",
|
||||
"init_parameters": {
|
||||
"api_key": {"env_vars": ["AZURE_OPENAI_API_KEY"], "strict": False, "type": "env_var"},
|
||||
"azure_ad_token": {"env_vars": ["AZURE_OPENAI_AD_TOKEN"], "strict": False, "type": "env_var"},
|
||||
"azure_deployment": "text-embedding-ada-002",
|
||||
"organization": None,
|
||||
"azure_endpoint": "https://example-resource.azure.openai.com/",
|
||||
|
||||
@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch
|
||||
import numpy as np
|
||||
import pytest
|
||||
from huggingface_hub.utils import RepositoryNotFoundError
|
||||
from haystack.utils.auth import Secret
|
||||
|
||||
from haystack.components.embedders.hugging_face_tei_document_embedder import HuggingFaceTEIDocumentEmbedder
|
||||
from haystack.dataclasses import Document
|
||||
@ -29,7 +30,7 @@ class TestHuggingFaceTEIDocumentEmbedder:
|
||||
|
||||
assert embedder.model == "BAAI/bge-small-en-v1.5"
|
||||
assert embedder.url is None
|
||||
assert embedder.token == "fake-api-token"
|
||||
assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
|
||||
assert embedder.prefix == ""
|
||||
assert embedder.suffix == ""
|
||||
assert embedder.batch_size == 32
|
||||
@ -41,7 +42,7 @@ class TestHuggingFaceTEIDocumentEmbedder:
|
||||
embedder = HuggingFaceTEIDocumentEmbedder(
|
||||
model="sentence-transformers/all-mpnet-base-v2",
|
||||
url="https://some_embedding_model.com",
|
||||
token="fake-api-token",
|
||||
token=Secret.from_token("fake-api-token"),
|
||||
prefix="prefix",
|
||||
suffix="suffix",
|
||||
batch_size=64,
|
||||
@ -52,7 +53,7 @@ class TestHuggingFaceTEIDocumentEmbedder:
|
||||
|
||||
assert embedder.model == "sentence-transformers/all-mpnet-base-v2"
|
||||
assert embedder.url == "https://some_embedding_model.com"
|
||||
assert embedder.token == "fake-api-token"
|
||||
assert embedder.token == Secret.from_token("fake-api-token")
|
||||
assert embedder.prefix == "prefix"
|
||||
assert embedder.suffix == "suffix"
|
||||
assert embedder.batch_size == 64
|
||||
@ -78,6 +79,7 @@ class TestHuggingFaceTEIDocumentEmbedder:
|
||||
"type": "haystack.components.embedders.hugging_face_tei_document_embedder.HuggingFaceTEIDocumentEmbedder",
|
||||
"init_parameters": {
|
||||
"model": "BAAI/bge-small-en-v1.5",
|
||||
"token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
|
||||
"url": None,
|
||||
"prefix": "",
|
||||
"suffix": "",
|
||||
@ -92,7 +94,7 @@ class TestHuggingFaceTEIDocumentEmbedder:
|
||||
component = HuggingFaceTEIDocumentEmbedder(
|
||||
model="sentence-transformers/all-mpnet-base-v2",
|
||||
url="https://some_embedding_model.com",
|
||||
token="fake-api-token",
|
||||
token=Secret.from_env_var("ENV_VAR", strict=False),
|
||||
prefix="prefix",
|
||||
suffix="suffix",
|
||||
batch_size=64,
|
||||
@ -106,6 +108,7 @@ class TestHuggingFaceTEIDocumentEmbedder:
|
||||
assert data == {
|
||||
"type": "haystack.components.embedders.hugging_face_tei_document_embedder.HuggingFaceTEIDocumentEmbedder",
|
||||
"init_parameters": {
|
||||
"token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
|
||||
"model": "sentence-transformers/all-mpnet-base-v2",
|
||||
"url": "https://some_embedding_model.com",
|
||||
"prefix": "prefix",
|
||||
@ -125,7 +128,7 @@ class TestHuggingFaceTEIDocumentEmbedder:
|
||||
embedder = HuggingFaceTEIDocumentEmbedder(
|
||||
model="sentence-transformers/all-mpnet-base-v2",
|
||||
url="https://some_embedding_model.com",
|
||||
token="fake-api-token",
|
||||
token=Secret.from_token("fake-api-token"),
|
||||
meta_fields_to_embed=["meta_field"],
|
||||
embedding_separator=" | ",
|
||||
)
|
||||
@ -146,7 +149,7 @@ class TestHuggingFaceTEIDocumentEmbedder:
|
||||
embedder = HuggingFaceTEIDocumentEmbedder(
|
||||
model="sentence-transformers/all-mpnet-base-v2",
|
||||
url="https://some_embedding_model.com",
|
||||
token="fake-api-token",
|
||||
token=Secret.from_token("fake-api-token"),
|
||||
prefix="my_prefix ",
|
||||
suffix=" my_suffix",
|
||||
)
|
||||
@ -168,7 +171,9 @@ class TestHuggingFaceTEIDocumentEmbedder:
|
||||
mock_embedding_patch.side_effect = mock_embedding_generation
|
||||
|
||||
embedder = HuggingFaceTEIDocumentEmbedder(
|
||||
model="BAAI/bge-small-en-v1.5", url="https://some_embedding_model.com", token="fake-api-token"
|
||||
model="BAAI/bge-small-en-v1.5",
|
||||
url="https://some_embedding_model.com",
|
||||
token=Secret.from_token("fake-api-token"),
|
||||
)
|
||||
embeddings = embedder._embed_batch(texts_to_embed=texts, batch_size=2)
|
||||
|
||||
@ -192,7 +197,7 @@ class TestHuggingFaceTEIDocumentEmbedder:
|
||||
|
||||
embedder = HuggingFaceTEIDocumentEmbedder(
|
||||
model="BAAI/bge-small-en-v1.5",
|
||||
token="fake-api-token",
|
||||
token=Secret.from_token("fake-api-token"),
|
||||
prefix="prefix ",
|
||||
suffix=" suffix",
|
||||
meta_fields_to_embed=["topic"],
|
||||
@ -228,7 +233,7 @@ class TestHuggingFaceTEIDocumentEmbedder:
|
||||
|
||||
embedder = HuggingFaceTEIDocumentEmbedder(
|
||||
model="BAAI/bge-small-en-v1.5",
|
||||
token="fake-api-token",
|
||||
token=Secret.from_token("fake-api-token"),
|
||||
prefix="prefix ",
|
||||
suffix=" suffix",
|
||||
meta_fields_to_embed=["topic"],
|
||||
@ -252,7 +257,9 @@ class TestHuggingFaceTEIDocumentEmbedder:
|
||||
|
||||
def test_run_wrong_input_format(self, mock_check_valid_model):
|
||||
embedder = HuggingFaceTEIDocumentEmbedder(
|
||||
model="BAAI/bge-small-en-v1.5", url="https://some_embedding_model.com", token="fake-api-token"
|
||||
model="BAAI/bge-small-en-v1.5",
|
||||
url="https://some_embedding_model.com",
|
||||
token=Secret.from_token("fake-api-token"),
|
||||
)
|
||||
|
||||
# wrong formats
|
||||
@ -267,7 +274,9 @@ class TestHuggingFaceTEIDocumentEmbedder:
|
||||
|
||||
def test_run_on_empty_list(self, mock_check_valid_model):
|
||||
embedder = HuggingFaceTEIDocumentEmbedder(
|
||||
model="BAAI/bge-small-en-v1.5", url="https://some_embedding_model.com", token="fake-api-token"
|
||||
model="BAAI/bge-small-en-v1.5",
|
||||
url="https://some_embedding_model.com",
|
||||
token=Secret.from_token("fake-api-token"),
|
||||
)
|
||||
|
||||
empty_list_input = []
|
||||
|
||||
@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch
|
||||
import numpy as np
|
||||
import pytest
|
||||
from huggingface_hub.utils import RepositoryNotFoundError
|
||||
from haystack.utils.auth import Secret
|
||||
|
||||
from haystack.components.embedders.hugging_face_tei_text_embedder import HuggingFaceTEITextEmbedder
|
||||
|
||||
@ -27,7 +28,7 @@ class TestHuggingFaceTEITextEmbedder:
|
||||
|
||||
assert embedder.model == "BAAI/bge-small-en-v1.5"
|
||||
assert embedder.url is None
|
||||
assert embedder.token == "fake-api-token"
|
||||
assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
|
||||
assert embedder.prefix == ""
|
||||
assert embedder.suffix == ""
|
||||
|
||||
@ -35,14 +36,14 @@ class TestHuggingFaceTEITextEmbedder:
|
||||
embedder = HuggingFaceTEITextEmbedder(
|
||||
model="sentence-transformers/all-mpnet-base-v2",
|
||||
url="https://some_embedding_model.com",
|
||||
token="fake-api-token",
|
||||
token=Secret.from_token("fake-api-token"),
|
||||
prefix="prefix",
|
||||
suffix="suffix",
|
||||
)
|
||||
|
||||
assert embedder.model == "sentence-transformers/all-mpnet-base-v2"
|
||||
assert embedder.url == "https://some_embedding_model.com"
|
||||
assert embedder.token == "fake-api-token"
|
||||
assert embedder.token == Secret.from_token("fake-api-token")
|
||||
assert embedder.prefix == "prefix"
|
||||
assert embedder.suffix == "suffix"
|
||||
|
||||
@ -62,14 +63,20 @@ class TestHuggingFaceTEITextEmbedder:
|
||||
|
||||
assert data == {
|
||||
"type": "haystack.components.embedders.hugging_face_tei_text_embedder.HuggingFaceTEITextEmbedder",
|
||||
"init_parameters": {"model": "BAAI/bge-small-en-v1.5", "url": None, "prefix": "", "suffix": ""},
|
||||
"init_parameters": {
|
||||
"token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
|
||||
"model": "BAAI/bge-small-en-v1.5",
|
||||
"url": None,
|
||||
"prefix": "",
|
||||
"suffix": "",
|
||||
},
|
||||
}
|
||||
|
||||
def test_to_dict_with_custom_init_parameters(self, mock_check_valid_model):
|
||||
component = HuggingFaceTEITextEmbedder(
|
||||
model="sentence-transformers/all-mpnet-base-v2",
|
||||
url="https://some_embedding_model.com",
|
||||
token="fake-api-token",
|
||||
token=Secret.from_env_var("ENV_VAR", strict=False),
|
||||
prefix="prefix",
|
||||
suffix="suffix",
|
||||
)
|
||||
@ -79,6 +86,7 @@ class TestHuggingFaceTEITextEmbedder:
|
||||
assert data == {
|
||||
"type": "haystack.components.embedders.hugging_face_tei_text_embedder.HuggingFaceTEITextEmbedder",
|
||||
"init_parameters": {
|
||||
"token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
|
||||
"model": "sentence-transformers/all-mpnet-base-v2",
|
||||
"url": "https://some_embedding_model.com",
|
||||
"prefix": "prefix",
|
||||
@ -91,7 +99,10 @@ class TestHuggingFaceTEITextEmbedder:
|
||||
mock_embedding_patch.side_effect = mock_embedding_generation
|
||||
|
||||
embedder = HuggingFaceTEITextEmbedder(
|
||||
model="BAAI/bge-small-en-v1.5", token="fake-api-token", prefix="prefix ", suffix=" suffix"
|
||||
model="BAAI/bge-small-en-v1.5",
|
||||
token=Secret.from_token("fake-api-token"),
|
||||
prefix="prefix ",
|
||||
suffix=" suffix",
|
||||
)
|
||||
|
||||
result = embedder.run(text="The food was delicious")
|
||||
@ -103,7 +114,9 @@ class TestHuggingFaceTEITextEmbedder:
|
||||
|
||||
def test_run_wrong_input_format(self, mock_check_valid_model):
|
||||
embedder = HuggingFaceTEITextEmbedder(
|
||||
model="BAAI/bge-small-en-v1.5", url="https://some_embedding_model.com", token="fake-api-token"
|
||||
model="BAAI/bge-small-en-v1.5",
|
||||
url="https://some_embedding_model.com",
|
||||
token=Secret.from_token("fake-api-token"),
|
||||
)
|
||||
|
||||
list_integers_input = [1, 2, 3]
|
||||
|
||||
@ -1,9 +1,9 @@
|
||||
import os
|
||||
from typing import List
|
||||
from haystack.utils.auth import Secret
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from openai import OpenAIError
|
||||
|
||||
from haystack import Document
|
||||
from haystack.components.embedders.openai_document_embedder import OpenAIDocumentEmbedder
|
||||
@ -37,7 +37,7 @@ class TestOpenAIDocumentEmbedder:
|
||||
|
||||
def test_init_with_parameters(self):
|
||||
embedder = OpenAIDocumentEmbedder(
|
||||
api_key="fake-api-key",
|
||||
api_key=Secret.from_token("fake-api-key"),
|
||||
model="model",
|
||||
organization="my-org",
|
||||
prefix="prefix",
|
||||
@ -58,15 +58,17 @@ class TestOpenAIDocumentEmbedder:
|
||||
|
||||
def test_init_fail_wo_api_key(self, monkeypatch):
|
||||
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||
with pytest.raises(OpenAIError):
|
||||
with pytest.raises(ValueError, match="None of the .* environment variables are set"):
|
||||
OpenAIDocumentEmbedder()
|
||||
|
||||
def test_to_dict(self):
|
||||
component = OpenAIDocumentEmbedder(api_key="fake-api-key")
|
||||
def test_to_dict(self, monkeypatch):
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key")
|
||||
component = OpenAIDocumentEmbedder()
|
||||
data = component.to_dict()
|
||||
assert data == {
|
||||
"type": "haystack.components.embedders.openai_document_embedder.OpenAIDocumentEmbedder",
|
||||
"init_parameters": {
|
||||
"api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
|
||||
"api_base_url": None,
|
||||
"model": "text-embedding-ada-002",
|
||||
"organization": None,
|
||||
@ -79,9 +81,10 @@ class TestOpenAIDocumentEmbedder:
|
||||
},
|
||||
}
|
||||
|
||||
def test_to_dict_with_custom_init_parameters(self):
|
||||
def test_to_dict_with_custom_init_parameters(self, monkeypatch):
|
||||
monkeypatch.setenv("ENV_VAR", "fake-api-key")
|
||||
component = OpenAIDocumentEmbedder(
|
||||
api_key="fake-api-key",
|
||||
api_key=Secret.from_env_var("ENV_VAR", strict=False),
|
||||
model="model",
|
||||
organization="my-org",
|
||||
prefix="prefix",
|
||||
@ -95,6 +98,7 @@ class TestOpenAIDocumentEmbedder:
|
||||
assert data == {
|
||||
"type": "haystack.components.embedders.openai_document_embedder.OpenAIDocumentEmbedder",
|
||||
"init_parameters": {
|
||||
"api_key": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
|
||||
"api_base_url": None,
|
||||
"model": "model",
|
||||
"organization": "my-org",
|
||||
@ -113,7 +117,7 @@ class TestOpenAIDocumentEmbedder:
|
||||
]
|
||||
|
||||
embedder = OpenAIDocumentEmbedder(
|
||||
api_key="fake-api-key", meta_fields_to_embed=["meta_field"], embedding_separator=" | "
|
||||
api_key=Secret.from_token("fake-api-key"), meta_fields_to_embed=["meta_field"], embedding_separator=" | "
|
||||
)
|
||||
|
||||
prepared_texts = embedder._prepare_texts_to_embed(documents)
|
||||
@ -130,7 +134,9 @@ class TestOpenAIDocumentEmbedder:
|
||||
def test_prepare_texts_to_embed_w_suffix(self):
|
||||
documents = [Document(content=f"document number {i}") for i in range(5)]
|
||||
|
||||
embedder = OpenAIDocumentEmbedder(api_key="fake-api-key", prefix="my_prefix ", suffix=" my_suffix")
|
||||
embedder = OpenAIDocumentEmbedder(
|
||||
api_key=Secret.from_token("fake-api-key"), prefix="my_prefix ", suffix=" my_suffix"
|
||||
)
|
||||
|
||||
prepared_texts = embedder._prepare_texts_to_embed(documents)
|
||||
|
||||
@ -143,7 +149,7 @@ class TestOpenAIDocumentEmbedder:
|
||||
]
|
||||
|
||||
def test_run_wrong_input_format(self):
|
||||
embedder = OpenAIDocumentEmbedder(api_key="fake-api-key")
|
||||
embedder = OpenAIDocumentEmbedder(api_key=Secret.from_token("fake-api-key"))
|
||||
|
||||
# wrong formats
|
||||
string_input = "text"
|
||||
@ -156,7 +162,7 @@ class TestOpenAIDocumentEmbedder:
|
||||
embedder.run(documents=list_integers_input)
|
||||
|
||||
def test_run_on_empty_list(self):
|
||||
embedder = OpenAIDocumentEmbedder(api_key="fake-api-key")
|
||||
embedder = OpenAIDocumentEmbedder(api_key=Secret.from_token("fake-api-key"))
|
||||
|
||||
empty_list_input = []
|
||||
result = embedder.run(documents=empty_list_input)
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import os
|
||||
from haystack.utils.auth import Secret
|
||||
|
||||
import pytest
|
||||
from openai import OpenAIError
|
||||
|
||||
from haystack.components.embedders.openai_text_embedder import OpenAITextEmbedder
|
||||
|
||||
@ -19,7 +19,11 @@ class TestOpenAITextEmbedder:
|
||||
|
||||
def test_init_with_parameters(self):
|
||||
embedder = OpenAITextEmbedder(
|
||||
api_key="fake-api-key", model="model", organization="fake-organization", prefix="prefix", suffix="suffix"
|
||||
api_key=Secret.from_token("fake-api-key"),
|
||||
model="model",
|
||||
organization="fake-organization",
|
||||
prefix="prefix",
|
||||
suffix="suffix",
|
||||
)
|
||||
assert embedder.client.api_key == "fake-api-key"
|
||||
assert embedder.model == "model"
|
||||
@ -29,25 +33,38 @@ class TestOpenAITextEmbedder:
|
||||
|
||||
def test_init_fail_wo_api_key(self, monkeypatch):
|
||||
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||
with pytest.raises(OpenAIError):
|
||||
with pytest.raises(ValueError, match="None of the .* environment variables are set"):
|
||||
OpenAITextEmbedder()
|
||||
|
||||
def test_to_dict(self):
|
||||
component = OpenAITextEmbedder(api_key="fake-api-key")
|
||||
def test_to_dict(self, monkeypatch):
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key")
|
||||
component = OpenAITextEmbedder()
|
||||
data = component.to_dict()
|
||||
assert data == {
|
||||
"type": "haystack.components.embedders.openai_text_embedder.OpenAITextEmbedder",
|
||||
"init_parameters": {"model": "text-embedding-ada-002", "organization": None, "prefix": "", "suffix": ""},
|
||||
"init_parameters": {
|
||||
"api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
|
||||
"model": "text-embedding-ada-002",
|
||||
"organization": None,
|
||||
"prefix": "",
|
||||
"suffix": "",
|
||||
},
|
||||
}
|
||||
|
||||
def test_to_dict_with_custom_init_parameters(self):
|
||||
def test_to_dict_with_custom_init_parameters(self, monkeypatch):
|
||||
monkeypatch.setenv("ENV_VAR", "fake-api-key")
|
||||
component = OpenAITextEmbedder(
|
||||
api_key="fake-api-key", model="model", organization="fake-organization", prefix="prefix", suffix="suffix"
|
||||
api_key=Secret.from_env_var("ENV_VAR", strict=False),
|
||||
model="model",
|
||||
organization="fake-organization",
|
||||
prefix="prefix",
|
||||
suffix="suffix",
|
||||
)
|
||||
data = component.to_dict()
|
||||
assert data == {
|
||||
"type": "haystack.components.embedders.openai_text_embedder.OpenAITextEmbedder",
|
||||
"init_parameters": {
|
||||
"api_key": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
|
||||
"model": "model",
|
||||
"organization": "fake-organization",
|
||||
"prefix": "prefix",
|
||||
@ -56,7 +73,7 @@ class TestOpenAITextEmbedder:
|
||||
}
|
||||
|
||||
def test_run_wrong_input_format(self):
|
||||
embedder = OpenAITextEmbedder(api_key="fake-api-key")
|
||||
embedder = OpenAITextEmbedder(api_key=Secret.from_token("fake-api-key"))
|
||||
|
||||
list_integers_input = [1, 2, 3]
|
||||
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from unittest.mock import patch, MagicMock
|
||||
import pytest
|
||||
import numpy as np
|
||||
from haystack.utils.auth import Secret
|
||||
|
||||
from haystack import Document
|
||||
from haystack.components.embedders.sentence_transformers_document_embedder import SentenceTransformersDocumentEmbedder
|
||||
@ -11,7 +12,7 @@ class TestSentenceTransformersDocumentEmbedder:
|
||||
embedder = SentenceTransformersDocumentEmbedder(model="model")
|
||||
assert embedder.model == "model"
|
||||
assert embedder.device == "cpu"
|
||||
assert embedder.token is None
|
||||
assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
|
||||
assert embedder.prefix == ""
|
||||
assert embedder.suffix == ""
|
||||
assert embedder.batch_size == 32
|
||||
@ -24,7 +25,7 @@ class TestSentenceTransformersDocumentEmbedder:
|
||||
embedder = SentenceTransformersDocumentEmbedder(
|
||||
model="model",
|
||||
device="cuda",
|
||||
token=True,
|
||||
token=Secret.from_token("fake-api-token"),
|
||||
prefix="prefix",
|
||||
suffix="suffix",
|
||||
batch_size=64,
|
||||
@ -35,7 +36,7 @@ class TestSentenceTransformersDocumentEmbedder:
|
||||
)
|
||||
assert embedder.model == "model"
|
||||
assert embedder.device == "cuda"
|
||||
assert embedder.token is True
|
||||
assert embedder.token == Secret.from_token("fake-api-token")
|
||||
assert embedder.prefix == "prefix"
|
||||
assert embedder.suffix == "suffix"
|
||||
assert embedder.batch_size == 64
|
||||
@ -52,7 +53,7 @@ class TestSentenceTransformersDocumentEmbedder:
|
||||
"init_parameters": {
|
||||
"model": "model",
|
||||
"device": "cpu",
|
||||
"token": None,
|
||||
"token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
|
||||
"prefix": "",
|
||||
"suffix": "",
|
||||
"batch_size": 32,
|
||||
@ -67,7 +68,7 @@ class TestSentenceTransformersDocumentEmbedder:
|
||||
component = SentenceTransformersDocumentEmbedder(
|
||||
model="model",
|
||||
device="cuda",
|
||||
token="the-token",
|
||||
token=Secret.from_env_var("ENV_VAR", strict=False),
|
||||
prefix="prefix",
|
||||
suffix="suffix",
|
||||
batch_size=64,
|
||||
@ -83,7 +84,7 @@ class TestSentenceTransformersDocumentEmbedder:
|
||||
"init_parameters": {
|
||||
"model": "model",
|
||||
"device": "cuda",
|
||||
"token": None, # the token is not serialized
|
||||
"token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
|
||||
"prefix": "prefix",
|
||||
"suffix": "suffix",
|
||||
"batch_size": 64,
|
||||
@ -98,10 +99,10 @@ class TestSentenceTransformersDocumentEmbedder:
|
||||
"haystack.components.embedders.sentence_transformers_document_embedder._SentenceTransformersEmbeddingBackendFactory"
|
||||
)
|
||||
def test_warmup(self, mocked_factory):
|
||||
embedder = SentenceTransformersDocumentEmbedder(model="model")
|
||||
embedder = SentenceTransformersDocumentEmbedder(model="model", token=None)
|
||||
mocked_factory.get_embedding_backend.assert_not_called()
|
||||
embedder.warm_up()
|
||||
mocked_factory.get_embedding_backend.assert_called_once_with(model="model", device="cpu", use_auth_token=None)
|
||||
mocked_factory.get_embedding_backend.assert_called_once_with(model="model", device="cpu", auth_token=None)
|
||||
|
||||
@patch(
|
||||
"haystack.components.embedders.sentence_transformers_document_embedder._SentenceTransformersEmbeddingBackendFactory"
|
||||
|
||||
@ -3,6 +3,7 @@ import pytest
|
||||
from haystack.components.embedders.backends.sentence_transformers_backend import (
|
||||
_SentenceTransformersEmbeddingBackendFactory,
|
||||
)
|
||||
from haystack.utils.auth import Secret
|
||||
|
||||
|
||||
@patch("haystack.components.embedders.backends.sentence_transformers_backend.SentenceTransformer")
|
||||
@ -22,10 +23,10 @@ def test_factory_behavior(mock_sentence_transformer):
|
||||
@patch("haystack.components.embedders.backends.sentence_transformers_backend.SentenceTransformer")
|
||||
def test_model_initialization(mock_sentence_transformer):
|
||||
_SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(
|
||||
model="model", device="cpu", use_auth_token="my_token"
|
||||
model="model", device="cpu", auth_token=Secret.from_token("fake-api-token")
|
||||
)
|
||||
mock_sentence_transformer.assert_called_once_with(
|
||||
model_name_or_path="model", device="cpu", use_auth_token="my_token"
|
||||
model_name_or_path="model", device="cpu", use_auth_token="fake-api-token"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from unittest.mock import patch, MagicMock
|
||||
import pytest
|
||||
from haystack.utils.auth import Secret
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -11,7 +12,7 @@ class TestSentenceTransformersTextEmbedder:
|
||||
embedder = SentenceTransformersTextEmbedder(model="model")
|
||||
assert embedder.model == "model"
|
||||
assert embedder.device == "cpu"
|
||||
assert embedder.token is None
|
||||
assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
|
||||
assert embedder.prefix == ""
|
||||
assert embedder.suffix == ""
|
||||
assert embedder.batch_size == 32
|
||||
@ -22,7 +23,7 @@ class TestSentenceTransformersTextEmbedder:
|
||||
embedder = SentenceTransformersTextEmbedder(
|
||||
model="model",
|
||||
device="cuda",
|
||||
token=True,
|
||||
token=Secret.from_token("fake-api-token"),
|
||||
prefix="prefix",
|
||||
suffix="suffix",
|
||||
batch_size=64,
|
||||
@ -31,7 +32,7 @@ class TestSentenceTransformersTextEmbedder:
|
||||
)
|
||||
assert embedder.model == "model"
|
||||
assert embedder.device == "cuda"
|
||||
assert embedder.token is True
|
||||
assert embedder.token == Secret.from_token("fake-api-token")
|
||||
assert embedder.prefix == "prefix"
|
||||
assert embedder.suffix == "suffix"
|
||||
assert embedder.batch_size == 64
|
||||
@ -44,9 +45,9 @@ class TestSentenceTransformersTextEmbedder:
|
||||
assert data == {
|
||||
"type": "haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder",
|
||||
"init_parameters": {
|
||||
"token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
|
||||
"model": "model",
|
||||
"device": "cpu",
|
||||
"token": None,
|
||||
"prefix": "",
|
||||
"suffix": "",
|
||||
"batch_size": 32,
|
||||
@ -59,7 +60,7 @@ class TestSentenceTransformersTextEmbedder:
|
||||
component = SentenceTransformersTextEmbedder(
|
||||
model="model",
|
||||
device="cuda",
|
||||
token=True,
|
||||
token=Secret.from_env_var("ENV_VAR", strict=False),
|
||||
prefix="prefix",
|
||||
suffix="suffix",
|
||||
batch_size=64,
|
||||
@ -70,9 +71,9 @@ class TestSentenceTransformersTextEmbedder:
|
||||
assert data == {
|
||||
"type": "haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder",
|
||||
"init_parameters": {
|
||||
"token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
|
||||
"model": "model",
|
||||
"device": "cuda",
|
||||
"token": True,
|
||||
"prefix": "prefix",
|
||||
"suffix": "suffix",
|
||||
"batch_size": 64,
|
||||
@ -82,30 +83,18 @@ class TestSentenceTransformersTextEmbedder:
|
||||
}
|
||||
|
||||
def test_to_dict_not_serialize_token(self):
|
||||
component = SentenceTransformersTextEmbedder(model="model", token="awesome-token")
|
||||
data = component.to_dict()
|
||||
assert data == {
|
||||
"type": "haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder",
|
||||
"init_parameters": {
|
||||
"model": "model",
|
||||
"device": "cpu",
|
||||
"token": None,
|
||||
"prefix": "",
|
||||
"suffix": "",
|
||||
"batch_size": 32,
|
||||
"progress_bar": True,
|
||||
"normalize_embeddings": False,
|
||||
},
|
||||
}
|
||||
component = SentenceTransformersTextEmbedder(model="model", token=Secret.from_token("fake-api-token"))
|
||||
with pytest.raises(ValueError, match="Cannot serialize token-based secret"):
|
||||
component.to_dict()
|
||||
|
||||
@patch(
|
||||
"haystack.components.embedders.sentence_transformers_text_embedder._SentenceTransformersEmbeddingBackendFactory"
|
||||
)
|
||||
def test_warmup(self, mocked_factory):
|
||||
embedder = SentenceTransformersTextEmbedder(model="model")
|
||||
embedder = SentenceTransformersTextEmbedder(model="model", token=None)
|
||||
mocked_factory.get_embedding_backend.assert_not_called()
|
||||
embedder.warm_up()
|
||||
mocked_factory.get_embedding_backend.assert_called_once_with(model="model", device="cpu", use_auth_token=None)
|
||||
mocked_factory.get_embedding_backend.assert_called_once_with(model="model", device="cpu", auth_token=None)
|
||||
|
||||
@patch(
|
||||
"haystack.components.embedders.sentence_transformers_text_embedder._SentenceTransformersEmbeddingBackendFactory"
|
||||
|
||||
@ -6,11 +6,13 @@ from openai import OpenAIError
|
||||
from haystack.components.generators.chat import AzureOpenAIChatGenerator
|
||||
from haystack.components.generators.utils import print_streaming_chunk
|
||||
from haystack.dataclasses import ChatMessage
|
||||
from haystack.utils.auth import Secret
|
||||
|
||||
|
||||
class TestOpenAIChatGenerator:
|
||||
def test_init_default(self):
|
||||
component = AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint", api_key="test-api-key")
|
||||
def test_init_default(self, monkeypatch):
|
||||
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key")
|
||||
component = AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint")
|
||||
assert component.client.api_key == "test-api-key"
|
||||
assert component.azure_deployment == "gpt-35-turbo"
|
||||
assert component.streaming_callback is None
|
||||
@ -18,13 +20,14 @@ class TestOpenAIChatGenerator:
|
||||
|
||||
def test_init_fail_wo_api_key(self, monkeypatch):
|
||||
monkeypatch.delenv("AZURE_OPENAI_API_KEY", raising=False)
|
||||
monkeypatch.delenv("AZURE_OPENAI_AD_TOKEN", raising=False)
|
||||
with pytest.raises(OpenAIError):
|
||||
AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint")
|
||||
|
||||
def test_init_with_parameters(self):
|
||||
component = AzureOpenAIChatGenerator(
|
||||
api_key=Secret.from_token("test-api-key"),
|
||||
azure_endpoint="some-non-existing-endpoint",
|
||||
api_key="test-api-key",
|
||||
streaming_callback=print_streaming_chunk,
|
||||
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
|
||||
)
|
||||
@ -33,12 +36,15 @@ class TestOpenAIChatGenerator:
|
||||
assert component.streaming_callback is print_streaming_chunk
|
||||
assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
|
||||
|
||||
def test_to_dict_default(self):
|
||||
component = AzureOpenAIChatGenerator(api_key="test-api-key", azure_endpoint="some-non-existing-endpoint")
|
||||
def test_to_dict_default(self, monkeypatch):
|
||||
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key")
|
||||
component = AzureOpenAIChatGenerator(azure_endpoint="some-non-existing-endpoint")
|
||||
data = component.to_dict()
|
||||
assert data == {
|
||||
"type": "haystack.components.generators.chat.azure.AzureOpenAIChatGenerator",
|
||||
"init_parameters": {
|
||||
"api_key": {"env_vars": ["AZURE_OPENAI_API_KEY"], "strict": False, "type": "env_var"},
|
||||
"azure_ad_token": {"env_vars": ["AZURE_OPENAI_AD_TOKEN"], "strict": False, "type": "env_var"},
|
||||
"api_version": "2023-05-15",
|
||||
"azure_endpoint": "some-non-existing-endpoint",
|
||||
"azure_deployment": "gpt-35-turbo",
|
||||
@ -48,9 +54,11 @@ class TestOpenAIChatGenerator:
|
||||
},
|
||||
}
|
||||
|
||||
def test_to_dict_with_parameters(self):
|
||||
def test_to_dict_with_parameters(self, monkeypatch):
|
||||
monkeypatch.setenv("ENV_VAR", "test-api-key")
|
||||
component = AzureOpenAIChatGenerator(
|
||||
api_key="test-api-key",
|
||||
api_key=Secret.from_env_var("ENV_VAR", strict=False),
|
||||
azure_ad_token=Secret.from_env_var("ENV_VAR1", strict=False),
|
||||
azure_endpoint="some-non-existing-endpoint",
|
||||
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
|
||||
)
|
||||
@ -58,6 +66,8 @@ class TestOpenAIChatGenerator:
|
||||
assert data == {
|
||||
"type": "haystack.components.generators.chat.azure.AzureOpenAIChatGenerator",
|
||||
"init_parameters": {
|
||||
"api_key": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
|
||||
"azure_ad_token": {"env_vars": ["ENV_VAR1"], "strict": False, "type": "env_var"},
|
||||
"api_version": "2023-05-15",
|
||||
"azure_endpoint": "some-non-existing-endpoint",
|
||||
"azure_deployment": "gpt-35-turbo",
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from unittest.mock import patch, Mock
|
||||
from haystack.utils.auth import Secret
|
||||
|
||||
import pytest
|
||||
from transformers import PreTrainedTokenizer
|
||||
@ -56,7 +57,7 @@ class TestHuggingFaceLocalChatGenerator:
|
||||
generator = HuggingFaceLocalChatGenerator(
|
||||
model="mistralai/Mistral-7B-Instruct-v0.2",
|
||||
task="text2text-generation",
|
||||
token="test-token",
|
||||
token=Secret.from_token("test-token"),
|
||||
device=ComponentDevice.from_str("cpu"),
|
||||
)
|
||||
|
||||
@ -72,6 +73,7 @@ class TestHuggingFaceLocalChatGenerator:
|
||||
model="mistralai/Mistral-7B-Instruct-v0.2",
|
||||
task="text2text-generation",
|
||||
device=ComponentDevice.from_str("cpu"),
|
||||
token=None,
|
||||
)
|
||||
|
||||
assert generator.huggingface_pipeline_kwargs == {
|
||||
@ -82,7 +84,9 @@ class TestHuggingFaceLocalChatGenerator:
|
||||
}
|
||||
|
||||
def test_init_task_parameter(self):
|
||||
generator = HuggingFaceLocalChatGenerator(task="text2text-generation", device=ComponentDevice.from_str("cpu"))
|
||||
generator = HuggingFaceLocalChatGenerator(
|
||||
task="text2text-generation", device=ComponentDevice.from_str("cpu"), token=None
|
||||
)
|
||||
|
||||
assert generator.huggingface_pipeline_kwargs == {
|
||||
"model": "HuggingFaceH4/zephyr-7b-beta",
|
||||
@ -93,7 +97,9 @@ class TestHuggingFaceLocalChatGenerator:
|
||||
|
||||
def test_init_task_in_huggingface_pipeline_kwargs(self):
|
||||
generator = HuggingFaceLocalChatGenerator(
|
||||
huggingface_pipeline_kwargs={"task": "text2text-generation"}, device=ComponentDevice.from_str("cpu")
|
||||
huggingface_pipeline_kwargs={"task": "text2text-generation"},
|
||||
device=ComponentDevice.from_str("cpu"),
|
||||
token=None,
|
||||
)
|
||||
|
||||
assert generator.huggingface_pipeline_kwargs == {
|
||||
@ -105,7 +111,7 @@ class TestHuggingFaceLocalChatGenerator:
|
||||
|
||||
def test_init_task_inferred_from_model_name(self, model_info_mock):
|
||||
generator = HuggingFaceLocalChatGenerator(
|
||||
model="mistralai/Mistral-7B-Instruct-v0.2", device=ComponentDevice.from_str("cpu")
|
||||
model="mistralai/Mistral-7B-Instruct-v0.2", device=ComponentDevice.from_str("cpu"), token=None
|
||||
)
|
||||
|
||||
assert generator.huggingface_pipeline_kwargs == {
|
||||
@ -122,7 +128,7 @@ class TestHuggingFaceLocalChatGenerator:
|
||||
def test_to_dict(self, model_info_mock):
|
||||
generator = HuggingFaceLocalChatGenerator(
|
||||
model="NousResearch/Llama-2-7b-chat-hf",
|
||||
token="token",
|
||||
token=Secret.from_env_var("ENV_VAR", strict=False),
|
||||
generation_kwargs={"n": 5},
|
||||
stop_words=["stop", "words"],
|
||||
streaming_callback=lambda x: x,
|
||||
@ -133,6 +139,7 @@ class TestHuggingFaceLocalChatGenerator:
|
||||
init_params = result["init_parameters"]
|
||||
|
||||
# Assert that the init_params dictionary contains the expected keys and values
|
||||
assert init_params["token"] == {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}
|
||||
assert init_params["huggingface_pipeline_kwargs"]["model"] == "NousResearch/Llama-2-7b-chat-hf"
|
||||
assert "token" not in init_params["huggingface_pipeline_kwargs"]
|
||||
assert init_params["generation_kwargs"] == {"max_new_tokens": 512, "n": 5, "stop_sequences": ["stop", "words"]}
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from unittest.mock import patch, MagicMock, Mock
|
||||
from haystack.utils.auth import Secret
|
||||
|
||||
import pytest
|
||||
from huggingface_hub.inference._text_generation import TextGenerationStreamResponse, Token, StreamDetails, FinishReason
|
||||
@ -59,7 +60,7 @@ class TestHuggingFaceTGIChatGenerator:
|
||||
# Initialize the HuggingFaceTGIChatGenerator object with valid parameters
|
||||
generator = HuggingFaceTGIChatGenerator(
|
||||
model="NousResearch/Llama-2-7b-chat-hf",
|
||||
token="token",
|
||||
token=Secret.from_env_var("ENV_VAR", strict=False),
|
||||
generation_kwargs={"n": 5},
|
||||
stop_words=["stop", "words"],
|
||||
streaming_callback=lambda x: x,
|
||||
@ -71,7 +72,7 @@ class TestHuggingFaceTGIChatGenerator:
|
||||
|
||||
# Assert that the init_params dictionary contains the expected keys and values
|
||||
assert init_params["model"] == "NousResearch/Llama-2-7b-chat-hf"
|
||||
assert init_params["token"] is None
|
||||
assert init_params["token"] == {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}
|
||||
assert init_params["generation_kwargs"] == {"n": 5, "stop_sequences": ["stop", "words"]}
|
||||
|
||||
def test_from_dict(self, mock_check_valid_model):
|
||||
|
||||
@ -2,6 +2,7 @@ import os
|
||||
|
||||
import pytest
|
||||
from openai import OpenAIError
|
||||
from haystack.utils.auth import Secret
|
||||
|
||||
from haystack.components.generators.chat import OpenAIChatGenerator
|
||||
from haystack.components.generators.utils import print_streaming_chunk
|
||||
@ -17,8 +18,9 @@ def chat_messages():
|
||||
|
||||
|
||||
class TestOpenAIChatGenerator:
|
||||
def test_init_default(self):
|
||||
component = OpenAIChatGenerator(api_key="test-api-key")
|
||||
def test_init_default(self, monkeypatch):
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
|
||||
component = OpenAIChatGenerator()
|
||||
assert component.client.api_key == "test-api-key"
|
||||
assert component.model == "gpt-3.5-turbo"
|
||||
assert component.streaming_callback is None
|
||||
@ -26,12 +28,12 @@ class TestOpenAIChatGenerator:
|
||||
|
||||
def test_init_fail_wo_api_key(self, monkeypatch):
|
||||
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||
with pytest.raises(OpenAIError):
|
||||
with pytest.raises(ValueError, match="None of the .* environment variables are set"):
|
||||
OpenAIChatGenerator()
|
||||
|
||||
def test_init_with_parameters(self):
|
||||
component = OpenAIChatGenerator(
|
||||
api_key="test-api-key",
|
||||
api_key=Secret.from_token("test-api-key"),
|
||||
model="gpt-4",
|
||||
streaming_callback=print_streaming_chunk,
|
||||
api_base_url="test-base-url",
|
||||
@ -42,12 +44,14 @@ class TestOpenAIChatGenerator:
|
||||
assert component.streaming_callback is print_streaming_chunk
|
||||
assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
|
||||
|
||||
def test_to_dict_default(self):
|
||||
component = OpenAIChatGenerator(api_key="test-api-key")
|
||||
def test_to_dict_default(self, monkeypatch):
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
|
||||
component = OpenAIChatGenerator()
|
||||
data = component.to_dict()
|
||||
assert data == {
|
||||
"type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
|
||||
"init_parameters": {
|
||||
"api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
|
||||
"model": "gpt-3.5-turbo",
|
||||
"organization": None,
|
||||
"streaming_callback": None,
|
||||
@ -56,9 +60,10 @@ class TestOpenAIChatGenerator:
|
||||
},
|
||||
}
|
||||
|
||||
def test_to_dict_with_parameters(self):
|
||||
def test_to_dict_with_parameters(self, monkeypatch):
|
||||
monkeypatch.setenv("ENV_VAR", "test-api-key")
|
||||
component = OpenAIChatGenerator(
|
||||
api_key="test-api-key",
|
||||
api_key=Secret.from_env_var("ENV_VAR"),
|
||||
model="gpt-4",
|
||||
streaming_callback=print_streaming_chunk,
|
||||
api_base_url="test-base-url",
|
||||
@ -68,6 +73,7 @@ class TestOpenAIChatGenerator:
|
||||
assert data == {
|
||||
"type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
|
||||
"init_parameters": {
|
||||
"api_key": {"env_vars": ["ENV_VAR"], "strict": True, "type": "env_var"},
|
||||
"model": "gpt-4",
|
||||
"organization": None,
|
||||
"api_base_url": "test-base-url",
|
||||
@ -76,9 +82,9 @@ class TestOpenAIChatGenerator:
|
||||
},
|
||||
}
|
||||
|
||||
def test_to_dict_with_lambda_streaming_callback(self):
|
||||
def test_to_dict_with_lambda_streaming_callback(self, monkeypatch):
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
|
||||
component = OpenAIChatGenerator(
|
||||
api_key="test-api-key",
|
||||
model="gpt-4",
|
||||
streaming_callback=lambda x: x,
|
||||
api_base_url="test-base-url",
|
||||
@ -88,6 +94,7 @@ class TestOpenAIChatGenerator:
|
||||
assert data == {
|
||||
"type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
|
||||
"init_parameters": {
|
||||
"api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
|
||||
"model": "gpt-4",
|
||||
"organization": None,
|
||||
"api_base_url": "test-base-url",
|
||||
@ -96,10 +103,12 @@ class TestOpenAIChatGenerator:
|
||||
},
|
||||
}
|
||||
|
||||
def test_from_dict(self):
|
||||
def test_from_dict(self, monkeypatch):
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "fake-api-key")
|
||||
data = {
|
||||
"type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
|
||||
"init_parameters": {
|
||||
"api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
|
||||
"model": "gpt-4",
|
||||
"api_base_url": "test-base-url",
|
||||
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
|
||||
@ -111,12 +120,14 @@ class TestOpenAIChatGenerator:
|
||||
assert component.streaming_callback is print_streaming_chunk
|
||||
assert component.api_base_url == "test-base-url"
|
||||
assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
|
||||
assert component.api_key == Secret.from_env_var("OPENAI_API_KEY")
|
||||
|
||||
def test_from_dict_fail_wo_env_var(self, monkeypatch):
|
||||
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||
data = {
|
||||
"type": "haystack.components.generators.chat.openai.OpenAIChatGenerator",
|
||||
"init_parameters": {
|
||||
"api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
|
||||
"model": "gpt-4",
|
||||
"organization": None,
|
||||
"api_base_url": "test-base-url",
|
||||
@ -124,11 +135,11 @@ class TestOpenAIChatGenerator:
|
||||
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
|
||||
},
|
||||
}
|
||||
with pytest.raises(OpenAIError):
|
||||
with pytest.raises(ValueError, match="None of the .* environment variables are set"):
|
||||
OpenAIChatGenerator.from_dict(data)
|
||||
|
||||
def test_run(self, chat_messages, mock_chat_completion):
|
||||
component = OpenAIChatGenerator()
|
||||
component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"))
|
||||
response = component.run(chat_messages)
|
||||
|
||||
# check that the component returns the correct ChatMessage response
|
||||
@ -139,7 +150,9 @@ class TestOpenAIChatGenerator:
|
||||
assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
|
||||
|
||||
def test_run_with_params(self, chat_messages, mock_chat_completion):
|
||||
component = OpenAIChatGenerator(generation_kwargs={"max_tokens": 10, "temperature": 0.5})
|
||||
component = OpenAIChatGenerator(
|
||||
api_key=Secret.from_token("test-api-key"), generation_kwargs={"max_tokens": 10, "temperature": 0.5}
|
||||
)
|
||||
response = component.run(chat_messages)
|
||||
|
||||
# check that the component calls the OpenAI API with the correct parameters
|
||||
@ -161,7 +174,9 @@ class TestOpenAIChatGenerator:
|
||||
nonlocal streaming_callback_called
|
||||
streaming_callback_called = True
|
||||
|
||||
component = OpenAIChatGenerator(streaming_callback=streaming_callback)
|
||||
component = OpenAIChatGenerator(
|
||||
api_key=Secret.from_token("test-api-key"), streaming_callback=streaming_callback
|
||||
)
|
||||
response = component.run(chat_messages)
|
||||
|
||||
# check we called the streaming callback
|
||||
@ -176,7 +191,7 @@ class TestOpenAIChatGenerator:
|
||||
assert "Hello" in response["replies"][0].content # see mock_chat_completion_chunk
|
||||
|
||||
def test_check_abnormal_completions(self, caplog):
|
||||
component = OpenAIChatGenerator(api_key="test-api-key")
|
||||
component = OpenAIChatGenerator(api_key=Secret.from_token("test-api-key"))
|
||||
messages = [
|
||||
ChatMessage.from_assistant(
|
||||
"", meta={"finish_reason": "content_filter" if i % 2 == 0 else "length", "index": i}
|
||||
@ -208,7 +223,7 @@ class TestOpenAIChatGenerator:
|
||||
@pytest.mark.integration
|
||||
def test_live_run(self):
|
||||
chat_messages = [ChatMessage.from_user("What's the capital of France")]
|
||||
component = OpenAIChatGenerator(api_key=os.environ.get("OPENAI_API_KEY"), generation_kwargs={"n": 1})
|
||||
component = OpenAIChatGenerator(generation_kwargs={"n": 1})
|
||||
results = component.run(chat_messages)
|
||||
assert len(results["replies"]) == 1
|
||||
message: ChatMessage = results["replies"][0]
|
||||
@ -222,7 +237,7 @@ class TestOpenAIChatGenerator:
|
||||
)
|
||||
@pytest.mark.integration
|
||||
def test_live_run_wrong_model(self, chat_messages):
|
||||
component = OpenAIChatGenerator(model="something-obviously-wrong", api_key=os.environ.get("OPENAI_API_KEY"))
|
||||
component = OpenAIChatGenerator(model="something-obviously-wrong")
|
||||
with pytest.raises(OpenAIError):
|
||||
component.run(chat_messages)
|
||||
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import os
|
||||
from haystack.utils.auth import Secret
|
||||
|
||||
import pytest
|
||||
from openai import OpenAIError
|
||||
@ -8,8 +9,9 @@ from haystack.components.generators.utils import print_streaming_chunk
|
||||
|
||||
|
||||
class TestAzureOpenAIGenerator:
|
||||
def test_init_default(self):
|
||||
component = AzureOpenAIGenerator(api_key="test-api-key", azure_endpoint="some-non-existing-endpoint")
|
||||
def test_init_default(self, monkeypatch):
|
||||
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key")
|
||||
component = AzureOpenAIGenerator(azure_endpoint="some-non-existing-endpoint")
|
||||
assert component.client.api_key == "test-api-key"
|
||||
assert component.azure_deployment == "gpt-35-turbo"
|
||||
assert component.streaming_callback is None
|
||||
@ -17,28 +19,32 @@ class TestAzureOpenAIGenerator:
|
||||
|
||||
def test_init_fail_wo_api_key(self, monkeypatch):
|
||||
monkeypatch.delenv("AZURE_OPENAI_API_KEY", raising=False)
|
||||
monkeypatch.delenv("AZURE_OPENAI_AD_TOKEN", raising=False)
|
||||
with pytest.raises(OpenAIError):
|
||||
AzureOpenAIGenerator(azure_endpoint="some-non-existing-endpoint")
|
||||
|
||||
def test_init_with_parameters(self):
|
||||
component = AzureOpenAIGenerator(
|
||||
api_key="test-api-key",
|
||||
api_key=Secret.from_token("fake-api-key"),
|
||||
azure_endpoint="some-non-existing-endpoint",
|
||||
azure_deployment="gpt-35-turbo",
|
||||
streaming_callback=print_streaming_chunk,
|
||||
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
|
||||
)
|
||||
assert component.client.api_key == "test-api-key"
|
||||
assert component.client.api_key == "fake-api-key"
|
||||
assert component.azure_deployment == "gpt-35-turbo"
|
||||
assert component.streaming_callback is print_streaming_chunk
|
||||
assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
|
||||
|
||||
def test_to_dict_default(self):
|
||||
component = AzureOpenAIGenerator(api_key="test-api-key", azure_endpoint="some-non-existing-endpoint")
|
||||
def test_to_dict_default(self, monkeypatch):
|
||||
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key")
|
||||
component = AzureOpenAIGenerator(azure_endpoint="some-non-existing-endpoint")
|
||||
data = component.to_dict()
|
||||
assert data == {
|
||||
"type": "haystack.components.generators.azure.AzureOpenAIGenerator",
|
||||
"init_parameters": {
|
||||
"api_key": {"env_vars": ["AZURE_OPENAI_API_KEY"], "strict": False, "type": "env_var"},
|
||||
"azure_ad_token": {"env_vars": ["AZURE_OPENAI_AD_TOKEN"], "strict": False, "type": "env_var"},
|
||||
"azure_deployment": "gpt-35-turbo",
|
||||
"api_version": "2023-05-15",
|
||||
"streaming_callback": None,
|
||||
@ -49,9 +55,11 @@ class TestAzureOpenAIGenerator:
|
||||
},
|
||||
}
|
||||
|
||||
def test_to_dict_with_parameters(self):
|
||||
def test_to_dict_with_parameters(self, monkeypatch):
|
||||
monkeypatch.setenv("ENV_VAR", "test-api-key")
|
||||
component = AzureOpenAIGenerator(
|
||||
api_key="test-api-key",
|
||||
api_key=Secret.from_env_var("ENV_VAR", strict=False),
|
||||
azure_ad_token=Secret.from_env_var("ENV_VAR1", strict=False),
|
||||
azure_endpoint="some-non-existing-endpoint",
|
||||
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
|
||||
)
|
||||
@ -60,6 +68,8 @@ class TestAzureOpenAIGenerator:
|
||||
assert data == {
|
||||
"type": "haystack.components.generators.azure.AzureOpenAIGenerator",
|
||||
"init_parameters": {
|
||||
"api_key": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
|
||||
"azure_ad_token": {"env_vars": ["ENV_VAR1"], "strict": False, "type": "env_var"},
|
||||
"azure_deployment": "gpt-35-turbo",
|
||||
"api_version": "2023-05-15",
|
||||
"streaming_callback": None,
|
||||
|
||||
@ -4,14 +4,16 @@ from unittest.mock import Mock, patch
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import PreTrainedTokenizerFast
|
||||
from haystack.utils.auth import Secret
|
||||
|
||||
from haystack.components.generators.hugging_face_local import HuggingFaceLocalGenerator, StopWordsCriteria
|
||||
from haystack.utils import ComponentDevice, Device
|
||||
from haystack.utils import ComponentDevice
|
||||
|
||||
|
||||
class TestHuggingFaceLocalGenerator:
|
||||
@patch("haystack.components.generators.hugging_face_local.model_info")
|
||||
def test_init_default(self, model_info_mock):
|
||||
def test_init_default(self, model_info_mock, monkeypatch):
|
||||
monkeypatch.delenv("HF_API_TOKEN", raising=False)
|
||||
model_info_mock.return_value.pipeline_tag = "text2text-generation"
|
||||
generator = HuggingFaceLocalGenerator()
|
||||
|
||||
@ -26,30 +28,33 @@ class TestHuggingFaceLocalGenerator:
|
||||
|
||||
def test_init_custom_token(self):
|
||||
generator = HuggingFaceLocalGenerator(
|
||||
model="google/flan-t5-base", task="text2text-generation", token="test-token"
|
||||
model="google/flan-t5-base", task="text2text-generation", token=Secret.from_token("fake-api-token")
|
||||
)
|
||||
|
||||
assert generator.huggingface_pipeline_kwargs == {
|
||||
"model": "google/flan-t5-base",
|
||||
"task": "text2text-generation",
|
||||
"token": "test-token",
|
||||
"token": "fake-api-token",
|
||||
"device": ComponentDevice.resolve_device(None).to_hf(),
|
||||
}
|
||||
|
||||
def test_init_custom_device(self):
|
||||
generator = HuggingFaceLocalGenerator(
|
||||
model="google/flan-t5-base", task="text2text-generation", device=ComponentDevice.from_str("cuda:0")
|
||||
model="google/flan-t5-base",
|
||||
task="text2text-generation",
|
||||
device=ComponentDevice.from_str("cuda:0"),
|
||||
token=Secret.from_token("fake-api-token"),
|
||||
)
|
||||
|
||||
assert generator.huggingface_pipeline_kwargs == {
|
||||
"model": "google/flan-t5-base",
|
||||
"task": "text2text-generation",
|
||||
"token": None,
|
||||
"token": "fake-api-token",
|
||||
"device": "cuda:0",
|
||||
}
|
||||
|
||||
def test_init_task_parameter(self):
|
||||
generator = HuggingFaceLocalGenerator(task="text2text-generation")
|
||||
generator = HuggingFaceLocalGenerator(task="text2text-generation", token=None)
|
||||
|
||||
assert generator.huggingface_pipeline_kwargs == {
|
||||
"model": "google/flan-t5-base",
|
||||
@ -59,7 +64,7 @@ class TestHuggingFaceLocalGenerator:
|
||||
}
|
||||
|
||||
def test_init_task_in_huggingface_pipeline_kwargs(self):
|
||||
generator = HuggingFaceLocalGenerator(huggingface_pipeline_kwargs={"task": "text2text-generation"})
|
||||
generator = HuggingFaceLocalGenerator(huggingface_pipeline_kwargs={"task": "text2text-generation"}, token=None)
|
||||
|
||||
assert generator.huggingface_pipeline_kwargs == {
|
||||
"model": "google/flan-t5-base",
|
||||
@ -71,7 +76,7 @@ class TestHuggingFaceLocalGenerator:
|
||||
@patch("haystack.components.generators.hugging_face_local.model_info")
|
||||
def test_init_task_inferred_from_model_name(self, model_info_mock):
|
||||
model_info_mock.return_value.pipeline_tag = "text2text-generation"
|
||||
generator = HuggingFaceLocalGenerator(model="google/flan-t5-base")
|
||||
generator = HuggingFaceLocalGenerator(model="google/flan-t5-base", token=None)
|
||||
|
||||
assert generator.huggingface_pipeline_kwargs == {
|
||||
"model": "google/flan-t5-base",
|
||||
@ -101,7 +106,7 @@ class TestHuggingFaceLocalGenerator:
|
||||
model="google/flan-t5-base",
|
||||
task="text2text-generation",
|
||||
device=ComponentDevice.from_str("cpu"),
|
||||
token="test-token",
|
||||
token=None,
|
||||
huggingface_pipeline_kwargs=huggingface_pipeline_kwargs,
|
||||
)
|
||||
|
||||
@ -142,10 +147,10 @@ class TestHuggingFaceLocalGenerator:
|
||||
assert data == {
|
||||
"type": "haystack.components.generators.hugging_face_local.HuggingFaceLocalGenerator",
|
||||
"init_parameters": {
|
||||
"token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
|
||||
"huggingface_pipeline_kwargs": {
|
||||
"model": "google/flan-t5-base",
|
||||
"task": "text2text-generation",
|
||||
"token": None,
|
||||
"device": ComponentDevice.resolve_device(None).to_hf(),
|
||||
},
|
||||
"generation_kwargs": {},
|
||||
@ -158,7 +163,7 @@ class TestHuggingFaceLocalGenerator:
|
||||
model="gpt2",
|
||||
task="text-generation",
|
||||
device=ComponentDevice.from_str("cuda:0"),
|
||||
token="test-token",
|
||||
token=Secret.from_env_var("ENV_VAR", strict=False),
|
||||
generation_kwargs={"max_new_tokens": 100},
|
||||
stop_words=["coca", "cola"],
|
||||
huggingface_pipeline_kwargs={
|
||||
@ -175,6 +180,7 @@ class TestHuggingFaceLocalGenerator:
|
||||
assert data == {
|
||||
"type": "haystack.components.generators.hugging_face_local.HuggingFaceLocalGenerator",
|
||||
"init_parameters": {
|
||||
"token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
|
||||
"huggingface_pipeline_kwargs": {
|
||||
"model": "gpt2",
|
||||
"task": "text-generation",
|
||||
@ -197,7 +203,7 @@ class TestHuggingFaceLocalGenerator:
|
||||
model="gpt2",
|
||||
task="text-generation",
|
||||
device=ComponentDevice.from_str("cuda:0"),
|
||||
token="test-token",
|
||||
token=None,
|
||||
generation_kwargs={"max_new_tokens": 100},
|
||||
stop_words=["coca", "cola"],
|
||||
huggingface_pipeline_kwargs={
|
||||
@ -216,6 +222,7 @@ class TestHuggingFaceLocalGenerator:
|
||||
assert data == {
|
||||
"type": "haystack.components.generators.hugging_face_local.HuggingFaceLocalGenerator",
|
||||
"init_parameters": {
|
||||
"token": None,
|
||||
"huggingface_pipeline_kwargs": {
|
||||
"model": "gpt2",
|
||||
"task": "text-generation",
|
||||
@ -239,6 +246,7 @@ class TestHuggingFaceLocalGenerator:
|
||||
data = {
|
||||
"type": "haystack.components.generators.hugging_face_local.HuggingFaceLocalGenerator",
|
||||
"init_parameters": {
|
||||
"token": None,
|
||||
"huggingface_pipeline_kwargs": {
|
||||
"model": "gpt2",
|
||||
"task": "text-generation",
|
||||
@ -276,9 +284,7 @@ class TestHuggingFaceLocalGenerator:
|
||||
|
||||
@patch("haystack.components.generators.hugging_face_local.pipeline")
|
||||
def test_warm_up(self, pipeline_mock):
|
||||
generator = HuggingFaceLocalGenerator(
|
||||
model="google/flan-t5-base", task="text2text-generation", token="test-token"
|
||||
)
|
||||
generator = HuggingFaceLocalGenerator(model="google/flan-t5-base", task="text2text-generation", token=None)
|
||||
pipeline_mock.assert_not_called()
|
||||
|
||||
generator.warm_up()
|
||||
@ -286,14 +292,14 @@ class TestHuggingFaceLocalGenerator:
|
||||
pipeline_mock.assert_called_once_with(
|
||||
model="google/flan-t5-base",
|
||||
task="text2text-generation",
|
||||
token="test-token",
|
||||
token=None,
|
||||
device=ComponentDevice.resolve_device(None).to_hf(),
|
||||
)
|
||||
|
||||
@patch("haystack.components.generators.hugging_face_local.pipeline")
|
||||
def test_warm_up_doesn_reload(self, pipeline_mock):
|
||||
generator = HuggingFaceLocalGenerator(
|
||||
model="google/flan-t5-base", task="text2text-generation", token="test-token"
|
||||
model="google/flan-t5-base", task="text2text-generation", token=Secret.from_token("fake-api-token")
|
||||
)
|
||||
|
||||
pipeline_mock.assert_not_called()
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from unittest.mock import patch, MagicMock, Mock
|
||||
from haystack.utils.auth import Secret
|
||||
|
||||
import pytest
|
||||
from huggingface_hub.inference._text_generation import TextGenerationStreamResponse, Token, StreamDetails, FinishReason
|
||||
@ -59,7 +60,10 @@ class TestHuggingFaceTGIGenerator:
|
||||
def test_to_dict(self, mock_check_valid_model):
|
||||
# Initialize the HuggingFaceRemoteGenerator object with valid parameters
|
||||
generator = HuggingFaceTGIGenerator(
|
||||
token="token", generation_kwargs={"n": 5}, stop_words=["stop", "words"], streaming_callback=lambda x: x
|
||||
token=Secret.from_env_var("ENV_VAR", strict=False),
|
||||
generation_kwargs={"n": 5},
|
||||
stop_words=["stop", "words"],
|
||||
streaming_callback=lambda x: x,
|
||||
)
|
||||
|
||||
# Call the to_dict method
|
||||
@ -68,7 +72,7 @@ class TestHuggingFaceTGIGenerator:
|
||||
|
||||
# Assert that the init_params dictionary contains the expected keys and values
|
||||
assert init_params["model"] == "mistralai/Mistral-7B-v0.1"
|
||||
assert not init_params["token"]
|
||||
assert init_params["token"] == {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}
|
||||
assert init_params["generation_kwargs"] == {"n": 5, "stop_sequences": ["stop", "words"]}
|
||||
|
||||
def test_from_dict(self, mock_check_valid_model):
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import os
|
||||
from typing import List
|
||||
from haystack.utils.auth import Secret
|
||||
|
||||
import pytest
|
||||
from openai import OpenAIError
|
||||
@ -10,8 +11,9 @@ from haystack.dataclasses import StreamingChunk, ChatMessage
|
||||
|
||||
|
||||
class TestOpenAIGenerator:
|
||||
def test_init_default(self):
|
||||
component = OpenAIGenerator(api_key="test-api-key")
|
||||
def test_init_default(self, monkeypatch):
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
|
||||
component = OpenAIGenerator()
|
||||
assert component.client.api_key == "test-api-key"
|
||||
assert component.model == "gpt-3.5-turbo"
|
||||
assert component.streaming_callback is None
|
||||
@ -19,12 +21,12 @@ class TestOpenAIGenerator:
|
||||
|
||||
def test_init_fail_wo_api_key(self, monkeypatch):
|
||||
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||
with pytest.raises(OpenAIError):
|
||||
with pytest.raises(ValueError, match="None of the .* environment variables are set"):
|
||||
OpenAIGenerator()
|
||||
|
||||
def test_init_with_parameters(self):
|
||||
component = OpenAIGenerator(
|
||||
api_key="test-api-key",
|
||||
api_key=Secret.from_token("test-api-key"),
|
||||
model="gpt-4",
|
||||
streaming_callback=print_streaming_chunk,
|
||||
api_base_url="test-base-url",
|
||||
@ -35,12 +37,14 @@ class TestOpenAIGenerator:
|
||||
assert component.streaming_callback is print_streaming_chunk
|
||||
assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
|
||||
|
||||
def test_to_dict_default(self):
|
||||
component = OpenAIGenerator(api_key="test-api-key")
|
||||
def test_to_dict_default(self, monkeypatch):
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
|
||||
component = OpenAIGenerator()
|
||||
data = component.to_dict()
|
||||
assert data == {
|
||||
"type": "haystack.components.generators.openai.OpenAIGenerator",
|
||||
"init_parameters": {
|
||||
"api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
|
||||
"model": "gpt-3.5-turbo",
|
||||
"streaming_callback": None,
|
||||
"system_prompt": None,
|
||||
@ -49,9 +53,10 @@ class TestOpenAIGenerator:
|
||||
},
|
||||
}
|
||||
|
||||
def test_to_dict_with_parameters(self):
|
||||
def test_to_dict_with_parameters(self, monkeypatch):
|
||||
monkeypatch.setenv("ENV_VAR", "test-api-key")
|
||||
component = OpenAIGenerator(
|
||||
api_key="test-api-key",
|
||||
api_key=Secret.from_env_var("ENV_VAR"),
|
||||
model="gpt-4",
|
||||
streaming_callback=print_streaming_chunk,
|
||||
api_base_url="test-base-url",
|
||||
@ -61,6 +66,7 @@ class TestOpenAIGenerator:
|
||||
assert data == {
|
||||
"type": "haystack.components.generators.openai.OpenAIGenerator",
|
||||
"init_parameters": {
|
||||
"api_key": {"env_vars": ["ENV_VAR"], "strict": True, "type": "env_var"},
|
||||
"model": "gpt-4",
|
||||
"system_prompt": None,
|
||||
"api_base_url": "test-base-url",
|
||||
@ -69,9 +75,9 @@ class TestOpenAIGenerator:
|
||||
},
|
||||
}
|
||||
|
||||
def test_to_dict_with_lambda_streaming_callback(self):
|
||||
def test_to_dict_with_lambda_streaming_callback(self, monkeypatch):
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
|
||||
component = OpenAIGenerator(
|
||||
api_key="test-api-key",
|
||||
model="gpt-4",
|
||||
streaming_callback=lambda x: x,
|
||||
api_base_url="test-base-url",
|
||||
@ -81,6 +87,7 @@ class TestOpenAIGenerator:
|
||||
assert data == {
|
||||
"type": "haystack.components.generators.openai.OpenAIGenerator",
|
||||
"init_parameters": {
|
||||
"api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
|
||||
"model": "gpt-4",
|
||||
"system_prompt": None,
|
||||
"api_base_url": "test-base-url",
|
||||
@ -94,6 +101,7 @@ class TestOpenAIGenerator:
|
||||
data = {
|
||||
"type": "haystack.components.generators.openai.OpenAIGenerator",
|
||||
"init_parameters": {
|
||||
"api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
|
||||
"model": "gpt-4",
|
||||
"system_prompt": None,
|
||||
"api_base_url": "test-base-url",
|
||||
@ -106,23 +114,25 @@ class TestOpenAIGenerator:
|
||||
assert component.streaming_callback is print_streaming_chunk
|
||||
assert component.api_base_url == "test-base-url"
|
||||
assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
|
||||
assert component.api_key == Secret.from_env_var("OPENAI_API_KEY")
|
||||
|
||||
def test_from_dict_fail_wo_env_var(self, monkeypatch):
|
||||
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||
data = {
|
||||
"type": "haystack.components.generators.openai.OpenAIGenerator",
|
||||
"init_parameters": {
|
||||
"api_key": {"env_vars": ["OPENAI_API_KEY"], "strict": True, "type": "env_var"},
|
||||
"model": "gpt-4",
|
||||
"api_base_url": "test-base-url",
|
||||
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
|
||||
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
|
||||
},
|
||||
}
|
||||
with pytest.raises(OpenAIError):
|
||||
with pytest.raises(ValueError, match="None of the .* environment variables are set"):
|
||||
OpenAIGenerator.from_dict(data)
|
||||
|
||||
def test_run(self, mock_chat_completion):
|
||||
component = OpenAIGenerator(api_key="test-api-key")
|
||||
component = OpenAIGenerator(api_key=Secret.from_token("test-api-key"))
|
||||
response = component.run("What's Natural Language Processing?")
|
||||
|
||||
# check that the component returns the correct ChatMessage response
|
||||
@ -139,7 +149,7 @@ class TestOpenAIGenerator:
|
||||
nonlocal streaming_callback_called
|
||||
streaming_callback_called = True
|
||||
|
||||
component = OpenAIGenerator(streaming_callback=streaming_callback)
|
||||
component = OpenAIGenerator(api_key=Secret.from_token("test-api-key"), streaming_callback=streaming_callback)
|
||||
response = component.run("Come on, stream!")
|
||||
|
||||
# check we called the streaming callback
|
||||
@ -153,7 +163,9 @@ class TestOpenAIGenerator:
|
||||
assert "Hello" in response["replies"][0] # see mock_chat_completion_chunk
|
||||
|
||||
def test_run_with_params(self, mock_chat_completion):
|
||||
component = OpenAIGenerator(api_key="test-api-key", generation_kwargs={"max_tokens": 10, "temperature": 0.5})
|
||||
component = OpenAIGenerator(
|
||||
api_key=Secret.from_token("test-api-key"), generation_kwargs={"max_tokens": 10, "temperature": 0.5}
|
||||
)
|
||||
response = component.run("What's Natural Language Processing?")
|
||||
|
||||
# check that the component calls the OpenAI API with the correct parameters
|
||||
@ -169,7 +181,7 @@ class TestOpenAIGenerator:
|
||||
assert [isinstance(reply, str) for reply in response["replies"]]
|
||||
|
||||
def test_check_abnormal_completions(self, caplog):
|
||||
component = OpenAIGenerator(api_key="test-api-key")
|
||||
component = OpenAIGenerator(api_key=Secret.from_token("test-api-key"))
|
||||
|
||||
# underlying implementation uses ChatMessage objects so we have to use them here
|
||||
messages: List[ChatMessage] = []
|
||||
@ -202,7 +214,7 @@ class TestOpenAIGenerator:
|
||||
)
|
||||
@pytest.mark.integration
|
||||
def test_live_run(self):
|
||||
component = OpenAIGenerator(api_key=os.environ.get("OPENAI_API_KEY"))
|
||||
component = OpenAIGenerator()
|
||||
results = component.run("What's the capital of France?")
|
||||
assert len(results["replies"]) == 1
|
||||
assert len(results["meta"]) == 1
|
||||
@ -224,7 +236,7 @@ class TestOpenAIGenerator:
|
||||
)
|
||||
@pytest.mark.integration
|
||||
def test_live_run_wrong_model(self):
|
||||
component = OpenAIGenerator(model="something-obviously-wrong", api_key=os.environ.get("OPENAI_API_KEY"))
|
||||
component = OpenAIGenerator(model="something-obviously-wrong")
|
||||
with pytest.raises(OpenAIError):
|
||||
component.run("Whatever")
|
||||
|
||||
@ -244,7 +256,7 @@ class TestOpenAIGenerator:
|
||||
self.responses += chunk.content if chunk.content else ""
|
||||
|
||||
callback = Callback()
|
||||
component = OpenAIGenerator(os.environ.get("OPENAI_API_KEY"), streaming_callback=callback)
|
||||
component = OpenAIGenerator(streaming_callback=callback)
|
||||
results = component.run("What's the capital of France?")
|
||||
|
||||
assert len(results["replies"]) == 1
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
from haystack.utils.auth import Secret
|
||||
|
||||
import pytest
|
||||
import logging
|
||||
@ -19,7 +20,7 @@ class TestSimilarityRanker:
|
||||
"init_parameters": {
|
||||
"device": None,
|
||||
"top_k": 10,
|
||||
"token": None,
|
||||
"token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
|
||||
"query_prefix": "",
|
||||
"document_prefix": "",
|
||||
"model": "cross-encoder/ms-marco-MiniLM-L-6-v2",
|
||||
@ -36,7 +37,7 @@ class TestSimilarityRanker:
|
||||
component = TransformersSimilarityRanker(
|
||||
model="my_model",
|
||||
device=ComponentDevice.from_str("cuda:0"),
|
||||
token="my_token",
|
||||
token=Secret.from_env_var("ENV_VAR", strict=False),
|
||||
top_k=5,
|
||||
query_prefix="query_instruction: ",
|
||||
document_prefix="document_instruction: ",
|
||||
@ -51,7 +52,7 @@ class TestSimilarityRanker:
|
||||
"init_parameters": {
|
||||
"device": None,
|
||||
"model": "my_model",
|
||||
"token": None, # we don't serialize valid tokens,
|
||||
"token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
|
||||
"top_k": 5,
|
||||
"query_prefix": "query_instruction: ",
|
||||
"document_prefix": "document_instruction: ",
|
||||
@ -84,7 +85,7 @@ class TestSimilarityRanker:
|
||||
"top_k": 10,
|
||||
"query_prefix": "",
|
||||
"document_prefix": "",
|
||||
"token": None,
|
||||
"token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
|
||||
"model": "cross-encoder/ms-marco-MiniLM-L-6-v2",
|
||||
"meta_fields_to_embed": [],
|
||||
"embedding_separator": "\n",
|
||||
@ -110,7 +111,7 @@ class TestSimilarityRanker:
|
||||
],
|
||||
)
|
||||
def test_to_dict_device_map(self, device_map, expected):
|
||||
component = TransformersSimilarityRanker(model_kwargs={"device_map": device_map})
|
||||
component = TransformersSimilarityRanker(model_kwargs={"device_map": device_map}, token=None)
|
||||
data = component.to_dict()
|
||||
|
||||
assert data == {
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import os
|
||||
from unittest.mock import Mock, patch
|
||||
from haystack.utils.auth import Secret
|
||||
|
||||
import pytest
|
||||
from requests import Timeout, RequestException, HTTPError
|
||||
@ -369,17 +370,19 @@ def mock_searchapi_search_result():
|
||||
class TestSearchApiSearchAPI:
|
||||
def test_init_fail_wo_api_key(self, monkeypatch):
|
||||
monkeypatch.delenv("SEARCHAPI_API_KEY", raising=False)
|
||||
with pytest.raises(ValueError, match="SearchApiWebSearch expects an API key"):
|
||||
with pytest.raises(ValueError, match="None of the .* environment variables are set"):
|
||||
SearchApiWebSearch()
|
||||
|
||||
def test_to_dict(self):
|
||||
def test_to_dict(self, monkeypatch):
|
||||
monkeypatch.setenv("SEARCHAPI_API_KEY", "test-api-key")
|
||||
component = SearchApiWebSearch(
|
||||
api_key="api_key", top_k=10, allowed_domains=["testdomain.com"], search_params={"param": "test params"}
|
||||
top_k=10, allowed_domains=["testdomain.com"], search_params={"param": "test params"}
|
||||
)
|
||||
data = component.to_dict()
|
||||
assert data == {
|
||||
"type": "haystack.components.websearch.searchapi.SearchApiWebSearch",
|
||||
"init_parameters": {
|
||||
"api_key": {"env_vars": ["SEARCHAPI_API_KEY"], "strict": True, "type": "env_var"},
|
||||
"top_k": 10,
|
||||
"allowed_domains": ["testdomain.com"],
|
||||
"search_params": {"param": "test params"},
|
||||
@ -388,7 +391,7 @@ class TestSearchApiSearchAPI:
|
||||
|
||||
@pytest.mark.parametrize("top_k", [1, 5, 7])
|
||||
def test_web_search_top_k(self, mock_searchapi_search_result, top_k: int):
|
||||
ws = SearchApiWebSearch(api_key="api_key", top_k=top_k)
|
||||
ws = SearchApiWebSearch(api_key=Secret.from_token("test-api-key"), top_k=top_k)
|
||||
results = ws.run(query="Who is CEO of Microsoft?")
|
||||
documents = results["documents"]
|
||||
links = results["links"]
|
||||
@ -400,7 +403,7 @@ class TestSearchApiSearchAPI:
|
||||
@patch("requests.get")
|
||||
def test_timeout_error(self, mock_get):
|
||||
mock_get.side_effect = Timeout
|
||||
ws = SearchApiWebSearch(api_key="api_key")
|
||||
ws = SearchApiWebSearch(api_key=Secret.from_token("test-api-key"))
|
||||
|
||||
with pytest.raises(TimeoutError):
|
||||
ws.run(query="Who is CEO of Microsoft?")
|
||||
@ -408,7 +411,7 @@ class TestSearchApiSearchAPI:
|
||||
@patch("requests.get")
|
||||
def test_request_exception(self, mock_get):
|
||||
mock_get.side_effect = RequestException
|
||||
ws = SearchApiWebSearch(api_key="api_key")
|
||||
ws = SearchApiWebSearch(api_key=Secret.from_token("test-api-key"))
|
||||
|
||||
with pytest.raises(SearchApiError):
|
||||
ws.run(query="Who is CEO of Microsoft?")
|
||||
@ -418,7 +421,7 @@ class TestSearchApiSearchAPI:
|
||||
mock_response = mock_get.return_value
|
||||
mock_response.status_code = 404
|
||||
mock_response.raise_for_status.side_effect = HTTPError
|
||||
ws = SearchApiWebSearch(api_key="api_key")
|
||||
ws = SearchApiWebSearch(api_key=Secret.from_token("test-api-key"))
|
||||
|
||||
with pytest.raises(SearchApiError):
|
||||
ws.run(query="Who is CEO of Microsoft?")
|
||||
@ -429,7 +432,7 @@ class TestSearchApiSearchAPI:
|
||||
)
|
||||
@pytest.mark.integration
|
||||
def test_web_search(self):
|
||||
ws = SearchApiWebSearch(api_key=os.environ.get("SEARCHAPI_API_KEY", None), top_k=10)
|
||||
ws = SearchApiWebSearch(top_k=10)
|
||||
results = ws.run(query="Who is CEO of Microsoft?")
|
||||
documents = results["documents"]
|
||||
links = results["links"]
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import os
|
||||
from unittest.mock import Mock, patch
|
||||
from haystack.utils.auth import Secret
|
||||
|
||||
import pytest
|
||||
from requests import Timeout, RequestException, HTTPError
|
||||
@ -110,22 +111,26 @@ def mock_serper_dev_search_result():
|
||||
class TestSerperDevSearchAPI:
|
||||
def test_init_fail_wo_api_key(self, monkeypatch):
|
||||
monkeypatch.delenv("SERPERDEV_API_KEY", raising=False)
|
||||
with pytest.raises(ValueError, match="SerperDevWebSearch expects an API key"):
|
||||
with pytest.raises(ValueError, match="None of the .* environment variables are set"):
|
||||
SerperDevWebSearch()
|
||||
|
||||
def test_to_dict(self):
|
||||
component = SerperDevWebSearch(
|
||||
api_key="test_key", top_k=10, allowed_domains=["test.com"], search_params={"param": "test"}
|
||||
)
|
||||
def test_to_dict(self, monkeypatch):
|
||||
monkeypatch.setenv("SERPERDEV_API_KEY", "test-api-key")
|
||||
component = SerperDevWebSearch(top_k=10, allowed_domains=["test.com"], search_params={"param": "test"})
|
||||
data = component.to_dict()
|
||||
assert data == {
|
||||
"type": "haystack.components.websearch.serper_dev.SerperDevWebSearch",
|
||||
"init_parameters": {"top_k": 10, "allowed_domains": ["test.com"], "search_params": {"param": "test"}},
|
||||
"init_parameters": {
|
||||
"api_key": {"env_vars": ["SERPERDEV_API_KEY"], "strict": True, "type": "env_var"},
|
||||
"top_k": 10,
|
||||
"allowed_domains": ["test.com"],
|
||||
"search_params": {"param": "test"},
|
||||
},
|
||||
}
|
||||
|
||||
@pytest.mark.parametrize("top_k", [1, 5, 7])
|
||||
def test_web_search_top_k(self, mock_serper_dev_search_result, top_k: int):
|
||||
ws = SerperDevWebSearch(api_key="some_invalid_key", top_k=top_k)
|
||||
ws = SerperDevWebSearch(api_key=Secret.from_token("test-api-key"), top_k=top_k)
|
||||
results = ws.run(query="Who is the boyfriend of Olivia Wilde?")
|
||||
documents = results["documents"]
|
||||
links = results["links"]
|
||||
@ -137,7 +142,7 @@ class TestSerperDevSearchAPI:
|
||||
@patch("requests.post")
|
||||
def test_timeout_error(self, mock_post):
|
||||
mock_post.side_effect = Timeout
|
||||
ws = SerperDevWebSearch(api_key="some_invalid_key")
|
||||
ws = SerperDevWebSearch(api_key=Secret.from_token("test-api-key"))
|
||||
|
||||
with pytest.raises(TimeoutError):
|
||||
ws.run(query="Who is the boyfriend of Olivia Wilde?")
|
||||
@ -145,7 +150,7 @@ class TestSerperDevSearchAPI:
|
||||
@patch("requests.post")
|
||||
def test_request_exception(self, mock_post):
|
||||
mock_post.side_effect = RequestException
|
||||
ws = SerperDevWebSearch(api_key="some_invalid_key")
|
||||
ws = SerperDevWebSearch(api_key=Secret.from_token("test-api-key"))
|
||||
|
||||
with pytest.raises(SerperDevError):
|
||||
ws.run(query="Who is the boyfriend of Olivia Wilde?")
|
||||
@ -155,7 +160,7 @@ class TestSerperDevSearchAPI:
|
||||
mock_response = mock_post.return_value
|
||||
mock_response.status_code = 404
|
||||
mock_response.raise_for_status.side_effect = HTTPError
|
||||
ws = SerperDevWebSearch(api_key="some_invalid_key")
|
||||
ws = SerperDevWebSearch(api_key=Secret.from_token("test-api-key"))
|
||||
|
||||
with pytest.raises(SerperDevError):
|
||||
ws.run(query="Who is the boyfriend of Olivia Wilde?")
|
||||
@ -166,7 +171,7 @@ class TestSerperDevSearchAPI:
|
||||
)
|
||||
@pytest.mark.integration
|
||||
def test_web_search(self):
|
||||
ws = SerperDevWebSearch(api_key=os.environ.get("SERPERDEV_API_KEY", None), top_k=10)
|
||||
ws = SerperDevWebSearch(top_k=10)
|
||||
results = ws.run(query="Who is the boyfriend of Olivia Wilde?")
|
||||
documents = results["documents"]
|
||||
links = results["documents"]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user