From f998bf4a4fc02b9f4d847bef065bb3d1666fb2cf Mon Sep 17 00:00:00 2001 From: Vivek Silimkhan <126159777+viveksilimkhan1@users.noreply.github.com> Date: Wed, 15 Nov 2023 17:56:29 +0530 Subject: [PATCH] feat: add Amazon Bedrock support (#6226) * Add Bedrock * Update supported models for Bedrock * Fix supports and add extract response in Bedrock * fix errors imports * improve and refactor supports * fix install * fix mypy * fix pylint * fix existing tests * Added Anthropic Bedrock * fix tests * fix sagemaker tests * add default prompt handler, constructor and supports tests * more tests * invoke refactoring * refactor model_kwargs * fix mypy * lstrip responses * Add streaming support * bump boto3 version * add class docstrings, better exception names * fix layer name * add tests for anthropic and cohere model adapters * update cohere params * update ai21 args and add tests * support cohere command light model * add tital tests * better class names * support meta llama 2 model * fix streaming support * more future-proof model adapter selection * fix import * fix mypy * fix pylint for preview * add tests for streaming * add release notes * Apply suggestions from code review Co-authored-by: Agnieszka Marzec <97166305+agnieszka-m@users.noreply.github.com> * fix format * fix tests after msg changes * fix streaming for cohere --------- Co-authored-by: tstadel <60758086+tstadel@users.noreply.github.com> Co-authored-by: tstadel Co-authored-by: Agnieszka Marzec <97166305+agnieszka-m@users.noreply.github.com> --- .github/workflows/linting_preview.yml | 5 +- haystack/errors.py | 21 + .../nodes/prompt/invocation_layer/__init__.py | 1 + .../prompt/invocation_layer/amazon_bedrock.py | 372 ++++++ .../nodes/prompt/invocation_layer/aws_base.py | 79 ++ .../prompt/invocation_layer/sagemaker_base.py | 74 +- .../invocation_layer/sagemaker_hf_infer.py | 2 +- .../invocation_layer/sagemaker_hf_text_gen.py | 2 +- .../prompt/invocation_layer/sagemaker_meta.py | 29 +- pyproject.toml | 8 +- .../bedrock-support-bce28e3078c85c12.yaml | 16 + .../invocation_layer/test_amazon_bedrock.py | 1033 +++++++++++++++++ .../test_invocation_layers.py | 4 +- .../test_sagemaker_hf_infer.py | 8 +- .../test_sagemaker_hf_text_gen.py | 11 +- .../invocation_layer/test_sagemaker_meta.py | 15 +- 16 files changed, 1564 insertions(+), 116 deletions(-) create mode 100644 haystack/nodes/prompt/invocation_layer/amazon_bedrock.py create mode 100644 haystack/nodes/prompt/invocation_layer/aws_base.py create mode 100644 releasenotes/notes/bedrock-support-bce28e3078c85c12.yaml create mode 100644 test/prompt/invocation_layer/test_amazon_bedrock.py diff --git a/.github/workflows/linting_preview.yml b/.github/workflows/linting_preview.yml index c69aeb91e..b4138d227 100644 --- a/.github/workflows/linting_preview.yml +++ b/.github/workflows/linting_preview.yml @@ -63,10 +63,7 @@ jobs: uses: tj-actions/changed-files@v40 with: files: | - **/*.py - files_ignore: | - test/** - rest_api/test/** + haystack/preview/**/*.py - uses: actions/setup-python@v4 with: diff --git a/haystack/errors.py b/haystack/errors.py index a9c1b7be8..f2b27aa31 100644 --- a/haystack/errors.py +++ b/haystack/errors.py @@ -202,6 +202,27 @@ class HuggingFaceInferenceUnauthorizedError(HuggingFaceInferenceError): """Exception for issues that occur in the HuggingFace inference node due to unauthorized access""" +class AWSConfigurationError(NodeError): + """Exception raised when AWS is not configured correctly""" + + def __init__(self, message: Optional[str] = None, send_message_in_event: bool 
= False): + super().__init__(message=message, send_message_in_event=send_message_in_event) + + +class AmazonBedrockConfigurationError(NodeError): + """Exception raised when AmazonBedrock node is not configured correctly""" + + def __init__(self, message: Optional[str] = None, send_message_in_event: bool = False): + super().__init__(message=message, send_message_in_event=send_message_in_event) + + +class AmazonBedrockInferenceError(NodeError): + """Exception for issues that occur in the Bedrock inference node""" + + def __init__(self, message: Optional[str] = None, send_message_in_event: bool = False): + super().__init__(message=message, send_message_in_event=send_message_in_event) + + class SageMakerInferenceError(NodeError): """Exception for issues that occur in the SageMaker inference node""" diff --git a/haystack/nodes/prompt/invocation_layer/__init__.py b/haystack/nodes/prompt/invocation_layer/__init__.py index f594bc0fc..778d557c1 100644 --- a/haystack/nodes/prompt/invocation_layer/__init__.py +++ b/haystack/nodes/prompt/invocation_layer/__init__.py @@ -8,6 +8,7 @@ from haystack.nodes.prompt.invocation_layer.anthropic_claude import AnthropicCla from haystack.nodes.prompt.invocation_layer.cohere import CohereInvocationLayer from haystack.nodes.prompt.invocation_layer.hugging_face import HFLocalInvocationLayer from haystack.nodes.prompt.invocation_layer.hugging_face_inference import HFInferenceEndpointInvocationLayer +from haystack.nodes.prompt.invocation_layer.amazon_bedrock import AmazonBedrockInvocationLayer from haystack.nodes.prompt.invocation_layer.sagemaker_meta import SageMakerMetaInvocationLayer from haystack.nodes.prompt.invocation_layer.sagemaker_hf_infer import SageMakerHFInferenceInvocationLayer from haystack.nodes.prompt.invocation_layer.sagemaker_hf_text_gen import SageMakerHFTextGenerationInvocationLayer diff --git a/haystack/nodes/prompt/invocation_layer/amazon_bedrock.py b/haystack/nodes/prompt/invocation_layer/amazon_bedrock.py new file mode 100644 index 000000000..51f88b14c --- /dev/null +++ b/haystack/nodes/prompt/invocation_layer/amazon_bedrock.py @@ -0,0 +1,372 @@ +from abc import ABC, abstractmethod +import json +import logging +import re +from typing import Any, Optional, Dict, Type, Union, List + +from haystack.errors import AWSConfigurationError, AmazonBedrockConfigurationError, AmazonBedrockInferenceError +from haystack.lazy_imports import LazyImport +from haystack.nodes.prompt.invocation_layer.aws_base import AWSBaseInvocationLayer +from haystack.nodes.prompt.invocation_layer.handlers import ( + DefaultPromptHandler, + DefaultTokenStreamingHandler, + TokenStreamingHandler, +) + +logger = logging.getLogger(__name__) + + +with LazyImport(message="Run 'pip install farm-haystack[aws]'") as boto3_import: + from botocore.exceptions import ClientError + + +class BedrockModelAdapter(ABC): + """ + Base class for Amazon Bedrock model adapters. 
+ """ + + def __init__(self, model_kwargs: Dict[str, Any], max_length: Optional[int]) -> None: + self.model_kwargs = model_kwargs + self.max_length = max_length + + @abstractmethod + def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]: + """Prepares the body for the Amazon Bedrock request.""" + + def get_responses(self, response_body: Dict[str, Any]) -> List[str]: + """Extracts the responses from the Amazon Bedrock response.""" + completions = self._extract_completions_from_response(response_body) + responses = [completion.lstrip() for completion in completions] + return responses + + def get_stream_responses(self, stream, stream_handler: TokenStreamingHandler) -> List[str]: + tokens: List[str] = [] + for event in stream: + chunk = event.get("chunk") + if chunk: + decoded_chunk = json.loads(chunk["bytes"].decode("utf-8")) + token = self._extract_token_from_stream(decoded_chunk) + tokens.append(stream_handler(token, event_data=decoded_chunk)) + responses = ["".join(tokens).lstrip()] + return responses + + def _get_params(self, inference_kwargs: Dict[str, Any], default_params: Dict[str, Any]) -> Dict[str, Any]: + """ + Merges the default params with the inference kwargs and model kwargs. + + Includes param if it's in kwargs or its default is not None (i.e. it is actually defined). + """ + kwargs = self.model_kwargs.copy() + kwargs.update(inference_kwargs) + return { + param: kwargs.get(param, default) + for param, default in default_params.items() + if param in kwargs or default is not None + } + + @abstractmethod + def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]: + """Extracts the responses from the Amazon Bedrock response.""" + + @abstractmethod + def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: + """Extracts the token from a streaming chunk.""" + + +class AnthropicClaudeAdapter(BedrockModelAdapter): + """ + Model adapter for the Anthropic's Claude model. + """ + + def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]: + default_params = { + "max_tokens_to_sample": self.max_length, + "stop_sequences": ["\n\nHuman:"], + "temperature": None, + "top_p": None, + "top_k": None, + } + params = self._get_params(inference_kwargs, default_params) + + body = {"prompt": f"\n\nHuman: {prompt}\n\nAssistant:", **params} + return body + + def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]: + return [response_body["completion"]] + + def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: + return chunk.get("completion", "") + + +class CohereCommandAdapter(BedrockModelAdapter): + """ + Model adapter for the Cohere's Command model. 
+ """ + + def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]: + default_params = { + "max_tokens": self.max_length, + "stop_sequences": None, + "temperature": None, + "p": None, + "k": None, + "return_likelihoods": None, + "stream": None, + "logit_bias": None, + "num_generations": None, + "truncate": None, + } + params = self._get_params(inference_kwargs, default_params) + + body = {"prompt": prompt, **params} + return body + + def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]: + responses = [generation["text"] for generation in response_body["generations"]] + return responses + + def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: + return chunk.get("text", "") + + +class AI21LabsJurassic2Adapter(BedrockModelAdapter): + """ + Model adapter for AI21 Labs' Jurassic 2 models. + """ + + def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]: + default_params = { + "maxTokens": self.max_length, + "stopSequences": None, + "temperature": None, + "topP": None, + "countPenalty": None, + "presencePenalty": None, + "frequencyPenalty": None, + "numResults": None, + } + params = self._get_params(inference_kwargs, default_params) + + body = {"prompt": prompt, **params} + return body + + def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]: + responses = [completion["data"]["text"] for completion in response_body["completions"]] + return responses + + def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: + raise NotImplementedError("Streaming is not supported for AI21 Jurassic 2 models.") + + +class AmazonTitanAdapter(BedrockModelAdapter): + """ + Model adapter for Amazon's Titan models. + """ + + def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]: + default_params = {"maxTokenCount": self.max_length, "stopSequences": None, "temperature": None, "topP": None} + params = self._get_params(inference_kwargs, default_params) + + body = {"inputText": prompt, "textGenerationConfig": params} + return body + + def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]: + responses = [result["outputText"] for result in response_body["results"]] + return responses + + def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: + return chunk.get("outputText", "") + + +class MetaLlama2ChatAdapter(BedrockModelAdapter): + """ + Model adapter for Meta's Llama 2 Chat models. + """ + + def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]: + default_params = {"max_gen_len": self.max_length, "temperature": None, "top_p": None} + params = self._get_params(inference_kwargs, default_params) + + body = {"prompt": prompt, **params} + return body + + def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]: + return [response_body["generation"]] + + def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: + return chunk.get("generation", "") + + +class AmazonBedrockInvocationLayer(AWSBaseInvocationLayer): + """ + Invocation layer for Amazon Bedrock models. 
+    """
+
+    SUPPORTED_MODEL_PATTERNS: Dict[str, Type[BedrockModelAdapter]] = {
+        r"amazon.titan-text.*": AmazonTitanAdapter,
+        r"ai21.j2.*": AI21LabsJurassic2Adapter,
+        r"cohere.command.*": CohereCommandAdapter,
+        r"anthropic.claude.*": AnthropicClaudeAdapter,
+        r"meta.llama2.*": MetaLlama2ChatAdapter,
+    }
+
+    def __init__(
+        self,
+        model_name_or_path: str,
+        aws_access_key_id: Optional[str] = None,
+        aws_secret_access_key: Optional[str] = None,
+        aws_session_token: Optional[str] = None,
+        aws_region_name: Optional[str] = None,
+        aws_profile_name: Optional[str] = None,
+        max_length: Optional[int] = 100,
+        **kwargs,
+    ):
+        super().__init__(model_name_or_path, **kwargs)
+        self.max_length = max_length
+
+        try:
+            session = self.get_aws_session(
+                aws_access_key_id=aws_access_key_id,
+                aws_secret_access_key=aws_secret_access_key,
+                aws_session_token=aws_session_token,
+                aws_region_name=aws_region_name,
+                aws_profile_name=aws_profile_name,
+            )
+            self.client = session.client("bedrock-runtime")
+        except Exception as exception:
+            raise AmazonBedrockConfigurationError(
+                "Could not connect to Amazon Bedrock. Make sure the AWS environment is configured correctly. "
+                "See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/quickstart.html#configuration"
+            ) from exception
+
+        model_input_kwargs = kwargs
+        # model_max_length is not sent to the model;
+        # it is only used to truncate the prompt if needed
+        model_max_length = kwargs.get("model_max_length", 4096)
+
+        # Truncate prompt if prompt tokens > model_max_length-max_length
+        # (max_length is the length of the generated text)
+        # It is hard to determine which tokenizer the Bedrock model uses,
+        # so we use the GPT2 tokenizer, which will likely provide a good token count approximation
+        self.prompt_handler = DefaultPromptHandler(
+            model_name_or_path="gpt2", model_max_length=model_max_length, max_length=self.max_length or 100
+        )
+
+        model_adapter_cls = self.get_model_adapter(model_name_or_path=model_name_or_path)
+        if not model_adapter_cls:
+            raise AmazonBedrockConfigurationError(
+                f"This invocation layer doesn't support the model {model_name_or_path}."
+            )
+        self.model_adapter = model_adapter_cls(model_kwargs=model_input_kwargs, max_length=self.max_length)
+
+    def _ensure_token_limit(self, prompt: Union[str, List[Dict[str, str]]]) -> Union[str, List[Dict[str, str]]]:
+        # the prompt for this model will be of the type str
+        if isinstance(prompt, List):
+            raise ValueError(
+                "The Amazon Bedrock invocation layer only supports a string as a prompt, "
+                "while currently, the prompt is a list of dictionaries."
+            )
+
+        resize_info = self.prompt_handler(prompt)
+        if resize_info["prompt_length"] != resize_info["new_prompt_length"]:
+            logger.warning(
+                "The prompt was truncated from %s tokens to %s tokens so that the prompt length and "
+                "the answer length (%s tokens) fit within the model's max token limit (%s tokens). 
" + "Shorten the prompt or it will be cut off.", + resize_info["prompt_length"], + max(0, resize_info["model_max_length"] - resize_info["max_length"]), # type: ignore + resize_info["max_length"], + resize_info["model_max_length"], + ) + return str(resize_info["resized_prompt"]) + + @classmethod + def supports(cls, model_name_or_path, **kwargs): + model_supported = cls.get_model_adapter(model_name_or_path) is not None + if not model_supported or not cls.aws_configured(**kwargs): + return False + + try: + session = cls.get_aws_session(**kwargs) + bedrock = session.client("bedrock") + foundation_models_response = bedrock.list_foundation_models(byOutputModality="TEXT") + available_model_ids = [entry["modelId"] for entry in foundation_models_response.get("modelSummaries", [])] + model_ids_supporting_streaming = [ + entry["modelId"] + for entry in foundation_models_response.get("modelSummaries", []) + if entry.get("responseStreamingSupported", False) + ] + except AWSConfigurationError as exception: + raise AmazonBedrockConfigurationError(message=exception.message) from exception + except Exception as exception: + raise AmazonBedrockConfigurationError( + "Could not connect to Amazon Bedrock. Make sure the AWS environment is configured correctly. " + "See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/quickstart.html#configuration" + ) from exception + + model_available = model_name_or_path in available_model_ids + if not model_available: + raise AmazonBedrockConfigurationError( + f"The model {model_name_or_path} is not available in Amazon Bedrock. " + f"Make sure the model you want to use is available in the configured AWS region and you have access." + ) + + stream: bool = kwargs.get("stream", False) + model_supports_streaming = model_name_or_path in model_ids_supporting_streaming + if stream and not model_supports_streaming: + raise AmazonBedrockConfigurationError( + f"The model {model_name_or_path} doesn't support streaming. Remove the `stream` parameter." + ) + + return model_supported + + def invoke(self, *args, **kwargs): + kwargs = kwargs.copy() + prompt: str = kwargs.pop("prompt", None) + stream: bool = kwargs.get("stream", self.model_adapter.model_kwargs.get("stream", False)) + + if not prompt or not isinstance(prompt, (str, list)): + raise ValueError( + f"The model {self.model_name_or_path} requires a valid prompt, but currently, it has no prompt. " + f"Make sure to provide a prompt in the format that the model expects." + ) + + body = self.model_adapter.prepare_body(prompt=prompt, **kwargs) + try: + if stream: + response = self.client.invoke_model_with_response_stream( + body=json.dumps(body), + modelId=self.model_name_or_path, + accept="application/json", + contentType="application/json", + ) + response_stream = response["body"] + handler: TokenStreamingHandler = kwargs.get( + "stream_handler", + self.model_adapter.model_kwargs.get("stream_handler", DefaultTokenStreamingHandler()), + ) + responses = self.model_adapter.get_stream_responses(stream=response_stream, stream_handler=handler) + else: + response = self.client.invoke_model( + body=json.dumps(body), + modelId=self.model_name_or_path, + accept="application/json", + contentType="application/json", + ) + response_body = json.loads(response.get("body").read().decode("utf-8")) + responses = self.model_adapter.get_responses(response_body=response_body) + except ClientError as exception: + raise AmazonBedrockInferenceError( + f"Could not connect to Amazon Bedrock model {self.model_name_or_path}. 
" + "Make sure your AWS environment is configured correctly, " + "the model is available in the configured AWS region, and you have access." + ) from exception + + return responses + + @classmethod + def get_model_adapter(cls, model_name_or_path: str) -> Optional[Type[BedrockModelAdapter]]: + for pattern, adapter in cls.SUPPORTED_MODEL_PATTERNS.items(): + if re.fullmatch(pattern, model_name_or_path): + return adapter + return None diff --git a/haystack/nodes/prompt/invocation_layer/aws_base.py b/haystack/nodes/prompt/invocation_layer/aws_base.py new file mode 100644 index 000000000..66a0ff7f5 --- /dev/null +++ b/haystack/nodes/prompt/invocation_layer/aws_base.py @@ -0,0 +1,79 @@ +import logging +from abc import ABC +from typing import Optional + + +from haystack.errors import AWSConfigurationError +from haystack.lazy_imports import LazyImport +from haystack.nodes.prompt.invocation_layer import PromptModelInvocationLayer + +logger = logging.getLogger(__name__) + + +with LazyImport(message="Run 'pip install farm-haystack[aws]'") as boto3_import: + import boto3 + from botocore.exceptions import BotoCoreError + + +AWS_CONFIGURATION_KEYS = [ + "aws_access_key_id", + "aws_secret_access_key", + "aws_session_token", + "aws_region_name", + "aws_profile_name", +] + + +class AWSBaseInvocationLayer(PromptModelInvocationLayer, ABC): + """ + Base class for AWS based invocation layers. + """ + + @classmethod + def aws_configured(cls, **kwargs) -> bool: + """ + Checks whether this invocation layer is active. + :param kwargs: The kwargs passed down to the invocation layer. + :return: True if the invocation layer is active, False otherwise. + """ + aws_config_provided = any(key in kwargs for key in AWS_CONFIGURATION_KEYS) + return aws_config_provided + + @classmethod + def get_aws_session( + cls, + aws_access_key_id: Optional[str] = None, + aws_secret_access_key: Optional[str] = None, + aws_session_token: Optional[str] = None, + aws_region_name: Optional[str] = None, + aws_profile_name: Optional[str] = None, + **kwargs, + ): + """ + Creates an AWS Session with the given parameters. + Checks if the provided AWS credentials are valid and can be used to connect to AWS. + + :param aws_access_key_id: AWS access key ID. + :param aws_secret_access_key: AWS secret access key. + :param aws_session_token: AWS session token. + :param aws_region_name: AWS region name. + :param aws_profile_name: AWS profile name. + :param kwargs: The kwargs passed down to the service client. Supported kwargs depend on the model chosen. + See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html. + :raises AWSConfigurationError: If the provided AWS credentials are invalid. + :return: The created AWS session. 
+ """ + boto3_import.check() + try: + return boto3.Session( + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, + region_name=aws_region_name, + profile_name=aws_profile_name, + ) + except BotoCoreError as e: + provided_aws_config = {k: v for k, v in kwargs.items() if k in AWS_CONFIGURATION_KEYS} + raise AWSConfigurationError( + f"Failed to initialize the session with provided AWS credentials {provided_aws_config}" + ) from e diff --git a/haystack/nodes/prompt/invocation_layer/sagemaker_base.py b/haystack/nodes/prompt/invocation_layer/sagemaker_base.py index e2829e6d1..b18e0f1f6 100644 --- a/haystack/nodes/prompt/invocation_layer/sagemaker_base.py +++ b/haystack/nodes/prompt/invocation_layer/sagemaker_base.py @@ -1,12 +1,12 @@ import json import logging from abc import abstractmethod, ABC -from typing import Optional, Dict, Union, List, Any +from typing import Dict, Union, List, Any -from haystack.errors import SageMakerConfigurationError +from haystack.errors import AWSConfigurationError, SageMakerConfigurationError from haystack.lazy_imports import LazyImport -from haystack.nodes.prompt.invocation_layer import PromptModelInvocationLayer +from haystack.nodes.prompt.invocation_layer.aws_base import AWSBaseInvocationLayer from haystack.nodes.prompt.invocation_layer.handlers import DefaultPromptHandler logger = logging.getLogger(__name__) @@ -14,10 +14,10 @@ logger = logging.getLogger(__name__) with LazyImport(message="Run 'pip install farm-haystack[aws]'") as boto3_import: import boto3 - from botocore.exceptions import ClientError, BotoCoreError + from botocore.exceptions import ClientError -class SageMakerBaseInvocationLayer(PromptModelInvocationLayer, ABC): +class SageMakerBaseInvocationLayer(AWSBaseInvocationLayer, ABC): """ Base class for SageMaker based invocation layers. """ @@ -70,18 +70,12 @@ class SageMakerBaseInvocationLayer(PromptModelInvocationLayer, ABC): :param model_name_or_path: The model_name_or_path to check. """ - aws_configuration_keys = [ - "aws_access_key_id", - "aws_secret_access_key", - "aws_session_token", - "aws_region_name", - "aws_profile_name", - ] - aws_config_provided = any(key in kwargs for key in aws_configuration_keys) - if aws_config_provided: - boto3_import.check() + if cls.aws_configured(**kwargs): # attempt to create a session with the provided credentials - session = cls.check_aws_connect(aws_configuration_keys, kwargs) + try: + session = cls.get_aws_session(**kwargs) + except AWSConfigurationError as e: + raise SageMakerConfigurationError(message=e.message) from e # is endpoint in service? cls.check_endpoint_in_service(session, model_name_or_path) @@ -91,24 +85,6 @@ class SageMakerBaseInvocationLayer(PromptModelInvocationLayer, ABC): return supported return False - @classmethod - def check_aws_connect(cls, aws_configuration_keys: List[str], kwargs): - """ - Checks if the provided AWS credentials are valid and can be used to connect to SageMaker. - :param aws_configuration_keys: The AWS configuration keys to check. - :param kwargs: The kwargs passed down to the SageMakerClient. - :return: The boto3 session. 
- """ - boto3_import.check() - try: - session = cls.create_session(**kwargs) - except BotoCoreError as e: - provided_aws_config = {k: v for k, v in kwargs.items() if k in aws_configuration_keys} - raise SageMakerConfigurationError( - f"Failed to initialize the session or client with provided AWS credentials {provided_aws_config}" - ) from e - return session - @classmethod def check_endpoint_in_service(cls, session: "boto3.Session", endpoint: str): """ @@ -176,33 +152,3 @@ class SageMakerBaseInvocationLayer(PromptModelInvocationLayer, ABC): if client: client.close() return True - - @classmethod - def create_session( - cls, - aws_access_key_id: Optional[str] = None, - aws_secret_access_key: Optional[str] = None, - aws_session_token: Optional[str] = None, - aws_region_name: Optional[str] = None, - aws_profile_name: Optional[str] = None, - **kwargs, - ): - """ - Creates an AWS Session with the given parameters. - - :param aws_access_key_id: AWS access key ID. - :param aws_secret_access_key: AWS secret access key. - :param aws_session_token: AWS session token. - :param aws_region_name: AWS region name. - :param aws_profile_name: AWS profile name. - :raise NoCredentialsError: If the AWS credentials are not provided or invalid. - :return: The created AWS Session. - """ - boto3_import.check() - return boto3.Session( - aws_access_key_id=aws_access_key_id, - aws_secret_access_key=aws_secret_access_key, - aws_session_token=aws_session_token, - region_name=aws_region_name, - profile_name=aws_profile_name, - ) diff --git a/haystack/nodes/prompt/invocation_layer/sagemaker_hf_infer.py b/haystack/nodes/prompt/invocation_layer/sagemaker_hf_infer.py index 3abcf5781..263d765a1 100644 --- a/haystack/nodes/prompt/invocation_layer/sagemaker_hf_infer.py +++ b/haystack/nodes/prompt/invocation_layer/sagemaker_hf_infer.py @@ -97,7 +97,7 @@ class SageMakerHFInferenceInvocationLayer(SageMakerBaseInvocationLayer): """ super().__init__(model_name_or_path, max_length=max_length, **kwargs) try: - session = SageMakerHFInferenceInvocationLayer.create_session( + session = self.get_aws_session( aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, aws_session_token=aws_session_token, diff --git a/haystack/nodes/prompt/invocation_layer/sagemaker_hf_text_gen.py b/haystack/nodes/prompt/invocation_layer/sagemaker_hf_text_gen.py index 29ff56728..0ac51f065 100644 --- a/haystack/nodes/prompt/invocation_layer/sagemaker_hf_text_gen.py +++ b/haystack/nodes/prompt/invocation_layer/sagemaker_hf_text_gen.py @@ -88,7 +88,7 @@ class SageMakerHFTextGenerationInvocationLayer(SageMakerBaseInvocationLayer): """ super().__init__(model_name_or_path, max_length=max_length, **kwargs) try: - session = SageMakerHFTextGenerationInvocationLayer.create_session( + session = self.get_aws_session( aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, aws_session_token=aws_session_token, diff --git a/haystack/nodes/prompt/invocation_layer/sagemaker_meta.py b/haystack/nodes/prompt/invocation_layer/sagemaker_meta.py index b66af49f0..e77912adc 100644 --- a/haystack/nodes/prompt/invocation_layer/sagemaker_meta.py +++ b/haystack/nodes/prompt/invocation_layer/sagemaker_meta.py @@ -4,15 +4,16 @@ from typing import Optional, Dict, List, Any, Union import requests -from haystack.errors import SageMakerModelNotReadyError, SageMakerInferenceError, SageMakerConfigurationError -from haystack.lazy_imports import LazyImport +from haystack.errors import ( + AWSConfigurationError, + SageMakerModelNotReadyError, + 
SageMakerInferenceError,
+    SageMakerConfigurationError,
+)
 from haystack.nodes.prompt.invocation_layer.sagemaker_base import SageMakerBaseInvocationLayer
 
 logger = logging.getLogger(__name__)
 
-with LazyImport(message="Run 'pip install farm-haystack[aws]'") as boto3_import:
-    pass
-
 
 class SageMakerMetaInvocationLayer(SageMakerBaseInvocationLayer):
     """
@@ -139,7 +140,7 @@ class SageMakerMetaInvocationLayer(SageMakerBaseInvocationLayer):
         kwargs.setdefault("model_max_length", 4096)
         super().__init__(model_name_or_path, max_length=max_length, **kwargs)
         try:
-            session = SageMakerMetaInvocationLayer.create_session(
+            session = self.get_aws_session(
                 aws_access_key_id=aws_access_key_id,
                 aws_secret_access_key=aws_secret_access_key,
                 aws_session_token=aws_session_token,
@@ -294,21 +295,15 @@ class SageMakerMetaInvocationLayer(SageMakerBaseInvocationLayer):
 
         :param model_name_or_path: The model_name_or_path to check.
         """
-        aws_configuration_keys = [
-            "aws_access_key_id",
-            "aws_secret_access_key",
-            "aws_session_token",
-            "aws_region_name",
-            "aws_profile_name",
-        ]
-        aws_config_provided = any(key in kwargs for key in aws_configuration_keys)
         accept_eula = False
         if "aws_custom_attributes" in kwargs and isinstance(kwargs["aws_custom_attributes"], dict):
             accept_eula = kwargs["aws_custom_attributes"].get("accept_eula", False)
-        if aws_config_provided and accept_eula:
-            boto3_import.check()
+        if cls.aws_configured(**kwargs) and accept_eula:
             # attempt to create a session with the provided credentials
-            session = cls.check_aws_connect(aws_configuration_keys, kwargs)
+            try:
+                session = cls.get_aws_session(**kwargs)
+            except AWSConfigurationError as e:
+                raise SageMakerConfigurationError(message=e.message) from e
             # is endpoint in service?
             cls.check_endpoint_in_service(session, model_name_or_path)
diff --git a/pyproject.toml b/pyproject.toml
index 61c6251c9..89d36390e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -152,12 +152,8 @@ docstores-gpu = [
     "farm-haystack[elasticsearch,faiss-gpu,weaviate,pinecone,opensearch]",
 ]
 aws = [
-    "boto3",
-    # Costraint botocore to avoid taking to much time to resolve the dependency tree.
-    # boto3 used to constraint it at this version more than a year ago. To avoid breaking
-    # people using old versions we use a similar constraint without upper bound.
-    # https://github.com/boto/boto3/blob/dae73bef223abbedfa7317a783070831febc0c90/setup.py#L16
-    "botocore>=1.27",
+    # first version to support Amazon Bedrock
+    "boto3>=1.28.57",
 ]
 crawler = [
     "selenium>=4.11.0"
diff --git a/releasenotes/notes/bedrock-support-bce28e3078c85c12.yaml b/releasenotes/notes/bedrock-support-bce28e3078c85c12.yaml
new file mode 100644
index 000000000..d31ec387c
--- /dev/null
+++ b/releasenotes/notes/bedrock-support-bce28e3078c85c12.yaml
@@ -0,0 +1,16 @@
+---
+prelude: >
+  Haystack now supports Amazon Bedrock models, including all existing and previously announced
+  models, like Llama-2-70b-chat. To use these models, simply pass the model ID in the
+  model_name_or_path parameter, like you do for any other model. For details, see
+  [Amazon Bedrock Documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html).
+
+  For example, the following code loads the Llama 2 Chat 13B model:
+  ```python
+  from haystack.nodes import PromptNode
+
+  prompt_node = PromptNode(model_name_or_path="meta.llama2-13b-chat-v1")
+  ```
+features:
+  - |
+    You can use Amazon Bedrock models in Haystack.
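Beyond the minimal example in the release note, the invocation layer added by this patch also reads AWS settings and model-specific inference parameters from the keyword arguments it receives. The sketch below shows how these could be passed through `PromptNode`'s `model_kwargs`; it is illustrative only, the profile and region names are placeholders, and the parameter values (`temperature`, `top_p`, `stream`) are examples of keys the Anthropic Claude adapter accepts.

```python
from haystack.nodes import PromptNode

# Sketch: "my-aws-profile" and "us-east-1" are placeholders; adjust them to your AWS setup.
# temperature/top_p map to the AnthropicClaudeAdapter's supported inference parameters,
# and stream=True enables token streaming for models that support it.
prompt_node = PromptNode(
    model_name_or_path="anthropic.claude-v2",
    max_length=256,
    model_kwargs={
        "aws_profile_name": "my-aws-profile",
        "aws_region_name": "us-east-1",
        "temperature": 0.3,
        "top_p": 0.9,
        "stream": True,
    },
)

print(prompt_node("Summarize what Amazon Bedrock is in two sentences."))
```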
diff --git a/test/prompt/invocation_layer/test_amazon_bedrock.py b/test/prompt/invocation_layer/test_amazon_bedrock.py new file mode 100644 index 000000000..a605c68cc --- /dev/null +++ b/test/prompt/invocation_layer/test_amazon_bedrock.py @@ -0,0 +1,1033 @@ +from typing import Optional, Type +from unittest.mock import call, patch, MagicMock + +import pytest + +from haystack.lazy_imports import LazyImport + +from haystack.errors import AmazonBedrockConfigurationError +from haystack.nodes.prompt.invocation_layer import AmazonBedrockInvocationLayer +from haystack.nodes.prompt.invocation_layer.amazon_bedrock import ( + AI21LabsJurassic2Adapter, + AnthropicClaudeAdapter, + BedrockModelAdapter, + CohereCommandAdapter, + AmazonTitanAdapter, + MetaLlama2ChatAdapter, +) + +with LazyImport() as boto3_import: + from botocore.exceptions import BotoCoreError + + +# create a fixture with mocked boto3 client and session +@pytest.fixture +def mock_boto3_session(): + with patch("boto3.Session") as mock_client: + yield mock_client + + +@pytest.fixture +def mock_prompt_handler(): + with patch("haystack.nodes.prompt.invocation_layer.handlers.DefaultPromptHandler") as mock_prompt_handler: + yield mock_prompt_handler + + +@pytest.mark.unit +def test_default_constructor(mock_auto_tokenizer, mock_boto3_session): + """ + Test that the default constructor sets the correct values + """ + + layer = AmazonBedrockInvocationLayer( + model_name_or_path="anthropic.claude-v2", + max_length=99, + aws_access_key_id="some_fake_id", + aws_secret_access_key="some_fake_key", + aws_session_token="some_fake_token", + aws_profile_name="some_fake_profile", + aws_region_name="fake_region", + ) + + assert layer.max_length == 99 + assert layer.model_name_or_path == "anthropic.claude-v2" + + assert layer.prompt_handler is not None + assert layer.prompt_handler.model_max_length == 4096 + + # assert mocked boto3 client called exactly once + mock_boto3_session.assert_called_once() + + # assert mocked boto3 client was called with the correct parameters + mock_boto3_session.assert_called_with( + aws_access_key_id="some_fake_id", + aws_secret_access_key="some_fake_key", + aws_session_token="some_fake_token", + profile_name="some_fake_profile", + region_name="fake_region", + ) + + +@pytest.mark.unit +def test_constructor_prompt_handler_initialized(mock_auto_tokenizer, mock_boto3_session): + """ + Test that the constructor sets the prompt_handler correctly, with the correct model_max_length for llama-2 + """ + layer = AmazonBedrockInvocationLayer(model_name_or_path="anthropic.claude-v2", prompt_handler=mock_prompt_handler) + assert layer.prompt_handler is not None + assert layer.prompt_handler.model_max_length == 4096 + + +@pytest.mark.unit +def test_constructor_with_model_kwargs(mock_auto_tokenizer, mock_boto3_session): + """ + Test that model_kwargs are correctly set in the constructor + """ + model_kwargs = {"temperature": 0.7} + + layer = AmazonBedrockInvocationLayer(model_name_or_path="anthropic.claude-v2", **model_kwargs) + assert "temperature" in layer.model_adapter.model_kwargs + assert layer.model_adapter.model_kwargs["temperature"] == 0.7 + + +@pytest.mark.unit +def test_constructor_with_empty_model_name(): + """ + Test that the constructor raises an error when the model_name_or_path is empty + """ + with pytest.raises(ValueError, match="cannot be None or empty string"): + AmazonBedrockInvocationLayer(model_name_or_path="") + + +@pytest.mark.unit +def test_invoke_with_no_kwargs(mock_auto_tokenizer, mock_boto3_session): + """ + Test 
invoke raises an error if no prompt is provided + """ + layer = AmazonBedrockInvocationLayer(model_name_or_path="anthropic.claude-v2") + with pytest.raises(ValueError, match="The model anthropic.claude-v2 requires a valid prompt."): + layer.invoke() + + +@pytest.mark.unit +def test_short_prompt_is_not_truncated(mock_boto3_session): + """ + Test that a short prompt is not truncated + """ + # Define a short mock prompt and its tokenized version + mock_prompt_text = "I am a tokenized prompt" + mock_prompt_tokens = mock_prompt_text.split() + + # Mock the tokenizer so it returns our predefined tokens + mock_tokenizer = MagicMock() + mock_tokenizer.tokenize.return_value = mock_prompt_tokens + + # We set a small max_length for generated text (3 tokens) and a total model_max_length of 10 tokens + # Since our mock prompt is 5 tokens long, it doesn't exceed the + # total limit (5 prompt tokens + 3 generated tokens < 10 tokens) + max_length_generated_text = 3 + total_model_max_length = 10 + + with patch("transformers.AutoTokenizer.from_pretrained", return_value=mock_tokenizer): + layer = AmazonBedrockInvocationLayer( + "anthropic.claude-v2", max_length=max_length_generated_text, model_max_length=total_model_max_length + ) + prompt_after_resize = layer._ensure_token_limit(mock_prompt_text) + + # The prompt doesn't exceed the limit, _ensure_token_limit doesn't truncate it + assert prompt_after_resize == mock_prompt_text + + +@pytest.mark.unit +def test_long_prompt_is_truncated(mock_boto3_session): + """ + Test that a long prompt is truncated + """ + # Define a long mock prompt and its tokenized version + long_prompt_text = "I am a tokenized prompt of length eight" + long_prompt_tokens = long_prompt_text.split() + + # _ensure_token_limit will truncate the prompt to make it fit into the model's max token limit + truncated_prompt_text = "I am a tokenized prompt of length" + + # Mock the tokenizer to return our predefined tokens + # convert tokens to our predefined truncated text + mock_tokenizer = MagicMock() + mock_tokenizer.tokenize.return_value = long_prompt_tokens + mock_tokenizer.convert_tokens_to_string.return_value = truncated_prompt_text + + # We set a small max_length for generated text (3 tokens) and a total model_max_length of 10 tokens + # Our mock prompt is 8 tokens long, so it exceeds the total limit (8 prompt tokens + 3 generated tokens > 10 tokens) + max_length_generated_text = 3 + total_model_max_length = 10 + + with patch("transformers.AutoTokenizer.from_pretrained", return_value=mock_tokenizer): + layer = AmazonBedrockInvocationLayer( + "anthropic.claude-v2", max_length=max_length_generated_text, model_max_length=total_model_max_length + ) + prompt_after_resize = layer._ensure_token_limit(long_prompt_text) + + # The prompt exceeds the limit, _ensure_token_limit truncates it + assert prompt_after_resize == truncated_prompt_text + + +@pytest.mark.unit +def test_supports_for_valid_aws_configuration(): + mock_session = MagicMock() + mock_session.client("bedrock").list_foundation_models.return_value = { + "modelSummaries": [{"modelId": "anthropic.claude-v2"}] + } + + # Patch the class method to return the mock session + with patch( + "haystack.nodes.prompt.invocation_layer.aws_base.AWSBaseInvocationLayer.get_aws_session", + return_value=mock_session, + ): + supported = AmazonBedrockInvocationLayer.supports( + model_name_or_path="anthropic.claude-v2", aws_profile_name="some_real_profile" + ) + args, kwargs = mock_session.client("bedrock").list_foundation_models.call_args + assert 
kwargs["byOutputModality"] == "TEXT" + + assert supported + + +@pytest.mark.unit +def test_supports_raises_on_invalid_aws_profile_name(): + with patch("boto3.Session") as mock_boto3_session: + mock_boto3_session.side_effect = BotoCoreError() + with pytest.raises(AmazonBedrockConfigurationError, match="Failed to initialize the session"): + AmazonBedrockInvocationLayer.supports( + model_name_or_path="anthropic.claude-v2", aws_profile_name="some_fake_profile" + ) + + +@pytest.mark.unit +def test_supports_for_invalid_bedrock_config(): + mock_session = MagicMock() + mock_session.client.side_effect = BotoCoreError() + + # Patch the class method to return the mock session + with patch( + "haystack.nodes.prompt.invocation_layer.aws_base.AWSBaseInvocationLayer.get_aws_session", + return_value=mock_session, + ), pytest.raises(AmazonBedrockConfigurationError, match="Could not connect to Amazon Bedrock."): + AmazonBedrockInvocationLayer.supports( + model_name_or_path="anthropic.claude-v2", aws_profile_name="some_real_profile" + ) + + +@pytest.mark.unit +def test_supports_for_invalid_bedrock_config_error_on_list_models(): + mock_session = MagicMock() + mock_session.client("bedrock").list_foundation_models.side_effect = BotoCoreError() + + # Patch the class method to return the mock session + with patch( + "haystack.nodes.prompt.invocation_layer.aws_base.AWSBaseInvocationLayer.get_aws_session", + return_value=mock_session, + ), pytest.raises(AmazonBedrockConfigurationError, match="Could not connect to Amazon Bedrock."): + AmazonBedrockInvocationLayer.supports( + model_name_or_path="anthropic.claude-v2", aws_profile_name="some_real_profile" + ) + + +@pytest.mark.unit +def test_supports_for_no_aws_params(): + supported = AmazonBedrockInvocationLayer.supports(model_name_or_path="anthropic.claude-v2") + + assert supported == False + + +@pytest.mark.unit +def test_supports_for_unknown_model(): + supported = AmazonBedrockInvocationLayer.supports( + model_name_or_path="unknown_model", aws_profile_name="some_real_profile" + ) + + assert supported == False + + +@pytest.mark.unit +def test_supports_with_stream_true_for_model_that_supports_streaming(): + mock_session = MagicMock() + mock_session.client("bedrock").list_foundation_models.return_value = { + "modelSummaries": [{"modelId": "anthropic.claude-v2", "responseStreamingSupported": True}] + } + + # Patch the class method to return the mock session + with patch( + "haystack.nodes.prompt.invocation_layer.aws_base.AWSBaseInvocationLayer.get_aws_session", + return_value=mock_session, + ): + supported = AmazonBedrockInvocationLayer.supports( + model_name_or_path="anthropic.claude-v2", aws_profile_name="some_real_profile", stream=True + ) + + assert supported == True + + +@pytest.mark.unit +def test_supports_with_stream_true_for_model_that_does_not_support_streaming(): + mock_session = MagicMock() + mock_session.client("bedrock").list_foundation_models.return_value = { + "modelSummaries": [{"modelId": "ai21.j2-mid-v1", "responseStreamingSupported": False}] + } + + # Patch the class method to return the mock session + with patch( + "haystack.nodes.prompt.invocation_layer.aws_base.AWSBaseInvocationLayer.get_aws_session", + return_value=mock_session, + ), pytest.raises(AmazonBedrockConfigurationError, match="The model ai21.j2-mid-v1 doesn't support streaming."): + AmazonBedrockInvocationLayer.supports( + model_name_or_path="ai21.j2-mid-v1", aws_profile_name="some_real_profile", stream=True + ) + + +@pytest.mark.unit +@pytest.mark.parametrize( + "model_name_or_path, 
expected_model_adapter", + [ + ("anthropic.claude-v1", AnthropicClaudeAdapter), + ("anthropic.claude-v2", AnthropicClaudeAdapter), + ("anthropic.claude-instant-v1", AnthropicClaudeAdapter), + ("anthropic.claude-super-v5", AnthropicClaudeAdapter), # artificial + ("cohere.command-text-v14", CohereCommandAdapter), + ("cohere.command-light-text-v14", CohereCommandAdapter), + ("cohere.command-text-v21", CohereCommandAdapter), # artificial + ("ai21.j2-mid-v1", AI21LabsJurassic2Adapter), + ("ai21.j2-ultra-v1", AI21LabsJurassic2Adapter), + ("ai21.j2-mega-v5", AI21LabsJurassic2Adapter), # artificial + ("amazon.titan-text-lite-v1", AmazonTitanAdapter), + ("amazon.titan-text-express-v1", AmazonTitanAdapter), + ("amazon.titan-text-agile-v1", AmazonTitanAdapter), + ("amazon.titan-text-lightning-v8", AmazonTitanAdapter), # artificial + ("meta.llama2-13b-chat-v1", MetaLlama2ChatAdapter), + ("meta.llama2-70b-chat-v1", MetaLlama2ChatAdapter), + ("meta.llama2-130b-v5", MetaLlama2ChatAdapter), # artificial + ("unknown_model", None), + ], +) +def test_get_model_adapter(model_name_or_path: str, expected_model_adapter: Optional[Type[BedrockModelAdapter]]): + """ + Test that the correct model adapter is returned for a given model_name_or_path + """ + model_adapter = AmazonBedrockInvocationLayer.get_model_adapter(model_name_or_path=model_name_or_path) + assert model_adapter == expected_model_adapter + + +class TestAnthropicClaudeAdapter: + def test_prepare_body_with_default_params(self) -> None: + layer = AnthropicClaudeAdapter(model_kwargs={}, max_length=99) + prompt = "Hello, how are you?" + expected_body = { + "prompt": "\n\nHuman: Hello, how are you?\n\nAssistant:", + "max_tokens_to_sample": 99, + "stop_sequences": ["\n\nHuman:"], + } + + body = layer.prepare_body(prompt) + + assert body == expected_body + + def test_prepare_body_with_custom_inference_params(self) -> None: + layer = AnthropicClaudeAdapter(model_kwargs={}, max_length=99) + prompt = "Hello, how are you?" + expected_body = { + "prompt": "\n\nHuman: Hello, how are you?\n\nAssistant:", + "max_tokens_to_sample": 50, + "stop_sequences": ["CUSTOM_STOP"], + "temperature": 0.7, + "top_p": 0.8, + "top_k": 5, + } + + body = layer.prepare_body( + prompt, + temperature=0.7, + top_p=0.8, + top_k=5, + max_tokens_to_sample=50, + stop_sequences=["CUSTOM_STOP"], + unknown_arg="unknown_value", + ) + + assert body == expected_body + + def test_prepare_body_with_model_kwargs(self) -> None: + layer = AnthropicClaudeAdapter( + model_kwargs={ + "temperature": 0.7, + "top_p": 0.8, + "top_k": 5, + "max_tokens_to_sample": 50, + "stop_sequences": ["CUSTOM_STOP"], + "unknown_arg": "unknown_value", + }, + max_length=99, + ) + prompt = "Hello, how are you?" + expected_body = { + "prompt": "\n\nHuman: Hello, how are you?\n\nAssistant:", + "max_tokens_to_sample": 50, + "stop_sequences": ["CUSTOM_STOP"], + "temperature": 0.7, + "top_p": 0.8, + "top_k": 5, + } + + body = layer.prepare_body(prompt) + + assert body == expected_body + + def test_prepare_body_with_model_kwargs_and_custom_inference_params(self) -> None: + layer = AnthropicClaudeAdapter( + model_kwargs={ + "temperature": 0.6, + "top_p": 0.7, + "top_k": 4, + "max_tokens_to_sample": 49, + "stop_sequences": ["CUSTOM_STOP_MODEL_KWARGS"], + }, + max_length=99, + ) + prompt = "Hello, how are you?" 
+ expected_body = { + "prompt": "\n\nHuman: Hello, how are you?\n\nAssistant:", + "max_tokens_to_sample": 50, + "stop_sequences": ["CUSTOM_STOP_MODEL_KWARGS"], + "temperature": 0.7, + "top_p": 0.8, + "top_k": 5, + } + + body = layer.prepare_body(prompt, temperature=0.7, top_p=0.8, top_k=5, max_tokens_to_sample=50) + + assert body == expected_body + + def test_get_responses(self) -> None: + adapter = AnthropicClaudeAdapter(model_kwargs={}, max_length=99) + response_body = {"completion": "This is a single response."} + expected_responses = ["This is a single response."] + assert adapter.get_responses(response_body) == expected_responses + + def test_get_responses_leading_whitespace(self) -> None: + adapter = AnthropicClaudeAdapter(model_kwargs={}, max_length=99) + response_body = {"completion": "\n\t This is a single response."} + expected_responses = ["This is a single response."] + assert adapter.get_responses(response_body) == expected_responses + + def test_get_stream_responses(self) -> None: + stream_mock = MagicMock() + stream_handler_mock = MagicMock() + + stream_mock.__iter__.return_value = [ + {"chunk": {"bytes": b'{"completion": " This"}'}}, + {"chunk": {"bytes": b'{"completion": " is"}'}}, + {"chunk": {"bytes": b'{"completion": " a"}'}}, + {"chunk": {"bytes": b'{"completion": " single"}'}}, + {"chunk": {"bytes": b'{"completion": " response."}'}}, + ] + + stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received + + adapter = AnthropicClaudeAdapter(model_kwargs={}, max_length=99) + expected_responses = ["This is a single response."] + assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses + + stream_handler_mock.assert_has_calls( + [ + call(" This", event_data={"completion": " This"}), + call(" is", event_data={"completion": " is"}), + call(" a", event_data={"completion": " a"}), + call(" single", event_data={"completion": " single"}), + call(" response.", event_data={"completion": " response."}), + ] + ) + + def test_get_stream_responses_empty(self) -> None: + stream_mock = MagicMock() + stream_handler_mock = MagicMock() + + stream_mock.__iter__.return_value = [] + + stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received + + adapter = AnthropicClaudeAdapter(model_kwargs={}, max_length=99) + expected_responses = [""] + assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses + + stream_handler_mock.assert_not_called() + + +class TestCohereCommandAdapter: + def test_prepare_body_with_default_params(self) -> None: + layer = CohereCommandAdapter(model_kwargs={}, max_length=99) + prompt = "Hello, how are you?" + expected_body = {"prompt": "Hello, how are you?", "max_tokens": 99} + + body = layer.prepare_body(prompt) + + assert body == expected_body + + def test_prepare_body_with_custom_inference_params(self) -> None: + layer = CohereCommandAdapter(model_kwargs={}, max_length=99) + prompt = "Hello, how are you?" 
+ expected_body = { + "prompt": "Hello, how are you?", + "max_tokens": 50, + "stop_sequences": ["CUSTOM_STOP"], + "temperature": 0.7, + "p": 0.8, + "k": 5, + "return_likelihoods": "GENERATION", + "stream": True, + "logit_bias": {"token_id": 10.0}, + "num_generations": 1, + "truncate": "START", + } + + body = layer.prepare_body( + prompt, + temperature=0.7, + p=0.8, + k=5, + max_tokens=50, + stop_sequences=["CUSTOM_STOP"], + return_likelihoods="GENERATION", + stream=True, + logit_bias={"token_id": 10.0}, + num_generations=1, + truncate="START", + unknown_arg="unknown_value", + ) + + assert body == expected_body + + def test_prepare_body_with_model_kwargs(self) -> None: + layer = CohereCommandAdapter( + model_kwargs={ + "temperature": 0.7, + "p": 0.8, + "k": 5, + "max_tokens": 50, + "stop_sequences": ["CUSTOM_STOP"], + "return_likelihoods": "GENERATION", + "stream": True, + "logit_bias": {"token_id": 10.0}, + "num_generations": 1, + "truncate": "START", + "unknown_arg": "unknown_value", + }, + max_length=99, + ) + prompt = "Hello, how are you?" + expected_body = { + "prompt": "Hello, how are you?", + "max_tokens": 50, + "stop_sequences": ["CUSTOM_STOP"], + "temperature": 0.7, + "p": 0.8, + "k": 5, + "return_likelihoods": "GENERATION", + "stream": True, + "logit_bias": {"token_id": 10.0}, + "num_generations": 1, + "truncate": "START", + } + + body = layer.prepare_body(prompt) + + assert body == expected_body + + def test_prepare_body_with_model_kwargs_and_custom_inference_params(self) -> None: + layer = CohereCommandAdapter( + model_kwargs={ + "temperature": 0.6, + "p": 0.7, + "k": 4, + "max_tokens": 49, + "stop_sequences": ["CUSTOM_STOP_MODEL_KWARGS"], + "return_likelihoods": "ALL", + "stream": False, + "logit_bias": {"token_id": 9.0}, + "num_generations": 2, + "truncate": "NONE", + }, + max_length=99, + ) + prompt = "Hello, how are you?" 
+ expected_body = { + "prompt": "Hello, how are you?", + "max_tokens": 50, + "stop_sequences": ["CUSTOM_STOP_MODEL_KWARGS"], + "temperature": 0.7, + "p": 0.8, + "k": 5, + "return_likelihoods": "GENERATION", + "stream": True, + "logit_bias": {"token_id": 10.0}, + "num_generations": 1, + "truncate": "START", + } + + body = layer.prepare_body( + prompt, + temperature=0.7, + p=0.8, + k=5, + max_tokens=50, + return_likelihoods="GENERATION", + stream=True, + logit_bias={"token_id": 10.0}, + num_generations=1, + truncate="START", + ) + + assert body == expected_body + + def test_get_responses(self) -> None: + adapter = CohereCommandAdapter(model_kwargs={}, max_length=99) + response_body = {"generations": [{"text": "This is a single response."}]} + expected_responses = ["This is a single response."] + assert adapter.get_responses(response_body) == expected_responses + + def test_get_responses_leading_whitespace(self) -> None: + adapter = CohereCommandAdapter(model_kwargs={}, max_length=99) + response_body = {"generations": [{"text": "\n\t This is a single response."}]} + expected_responses = ["This is a single response."] + assert adapter.get_responses(response_body) == expected_responses + + def test_get_responses_multiple_responses(self) -> None: + adapter = CohereCommandAdapter(model_kwargs={}, max_length=99) + response_body = { + "generations": [{"text": "This is a single response."}, {"text": "This is a second response."}] + } + expected_responses = ["This is a single response.", "This is a second response."] + assert adapter.get_responses(response_body) == expected_responses + + def test_get_stream_responses(self) -> None: + stream_mock = MagicMock() + stream_handler_mock = MagicMock() + + stream_mock.__iter__.return_value = [ + {"chunk": {"bytes": b'{"text": " This"}'}}, + {"chunk": {"bytes": b'{"text": " is"}'}}, + {"chunk": {"bytes": b'{"text": " a"}'}}, + {"chunk": {"bytes": b'{"text": " single"}'}}, + {"chunk": {"bytes": b'{"text": " response."}'}}, + {"chunk": {"bytes": b'{"finish_reason": "MAX_TOKENS", "is_finished": true}'}}, + ] + + stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received + + adapter = CohereCommandAdapter(model_kwargs={}, max_length=99) + expected_responses = ["This is a single response."] + assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses + + stream_handler_mock.assert_has_calls( + [ + call(" This", event_data={"text": " This"}), + call(" is", event_data={"text": " is"}), + call(" a", event_data={"text": " a"}), + call(" single", event_data={"text": " single"}), + call(" response.", event_data={"text": " response."}), + call("", event_data={"finish_reason": "MAX_TOKENS", "is_finished": True}), + ] + ) + + def test_get_stream_responses_empty(self) -> None: + stream_mock = MagicMock() + stream_handler_mock = MagicMock() + + stream_mock.__iter__.return_value = [] + + stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received + + adapter = CohereCommandAdapter(model_kwargs={}, max_length=99) + expected_responses = [""] + assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses + + stream_handler_mock.assert_not_called() + + +class TestAI21LabsJurrasic2Adapter: + def test_prepare_body_with_default_params(self) -> None: + layer = AI21LabsJurassic2Adapter(model_kwargs={}, max_length=99) + prompt = "Hello, how are you?" 
+ expected_body = {"prompt": "Hello, how are you?", "maxTokens": 99} + + body = layer.prepare_body(prompt) + + assert body == expected_body + + def test_prepare_body_with_custom_inference_params(self) -> None: + layer = AI21LabsJurassic2Adapter(model_kwargs={}, max_length=99) + prompt = "Hello, how are you?" + expected_body = { + "prompt": "Hello, how are you?", + "maxTokens": 50, + "stopSequences": ["CUSTOM_STOP"], + "temperature": 0.7, + "topP": 0.8, + "countPenalty": {"scale": 1.0}, + "presencePenalty": {"scale": 5.0}, + "frequencyPenalty": {"scale": 500.0}, + "numResults": 1, + } + + body = layer.prepare_body( + prompt, + maxTokens=50, + stopSequences=["CUSTOM_STOP"], + temperature=0.7, + topP=0.8, + countPenalty={"scale": 1.0}, + presencePenalty={"scale": 5.0}, + frequencyPenalty={"scale": 500.0}, + numResults=1, + unknown_arg="unknown_value", + ) + + assert body == expected_body + + def test_prepare_body_with_model_kwargs(self) -> None: + layer = AI21LabsJurassic2Adapter( + model_kwargs={ + "maxTokens": 50, + "stopSequences": ["CUSTOM_STOP"], + "temperature": 0.7, + "topP": 0.8, + "countPenalty": {"scale": 1.0}, + "presencePenalty": {"scale": 5.0}, + "frequencyPenalty": {"scale": 500.0}, + "numResults": 1, + "unknown_arg": "unknown_value", + }, + max_length=99, + ) + prompt = "Hello, how are you?" + expected_body = { + "prompt": "Hello, how are you?", + "maxTokens": 50, + "stopSequences": ["CUSTOM_STOP"], + "temperature": 0.7, + "topP": 0.8, + "countPenalty": {"scale": 1.0}, + "presencePenalty": {"scale": 5.0}, + "frequencyPenalty": {"scale": 500.0}, + "numResults": 1, + } + + body = layer.prepare_body(prompt) + + assert body == expected_body + + def test_prepare_body_with_model_kwargs_and_custom_inference_params(self) -> None: + layer = AI21LabsJurassic2Adapter( + model_kwargs={ + "maxTokens": 49, + "stopSequences": ["CUSTOM_STOP_MODEL_KWARGS"], + "temperature": 0.6, + "topP": 0.7, + "countPenalty": {"scale": 0.9}, + "presencePenalty": {"scale": 4.0}, + "frequencyPenalty": {"scale": 499.0}, + "numResults": 2, + "unknown_arg": "unknown_value", + }, + max_length=99, + ) + prompt = "Hello, how are you?" 
+ expected_body = { + "prompt": "Hello, how are you?", + "maxTokens": 50, + "stopSequences": ["CUSTOM_STOP_MODEL_KWARGS"], + "temperature": 0.7, + "topP": 0.8, + "countPenalty": {"scale": 1.0}, + "presencePenalty": {"scale": 5.0}, + "frequencyPenalty": {"scale": 500.0}, + "numResults": 1, + } + + body = layer.prepare_body( + prompt, + temperature=0.7, + topP=0.8, + maxTokens=50, + countPenalty={"scale": 1.0}, + presencePenalty={"scale": 5.0}, + frequencyPenalty={"scale": 500.0}, + numResults=1, + ) + + assert body == expected_body + + def test_get_responses(self) -> None: + adapter = AI21LabsJurassic2Adapter(model_kwargs={}, max_length=99) + response_body = {"completions": [{"data": {"text": "This is a single response."}}]} + expected_responses = ["This is a single response."] + assert adapter.get_responses(response_body) == expected_responses + + def test_get_responses_leading_whitespace(self) -> None: + adapter = AI21LabsJurassic2Adapter(model_kwargs={}, max_length=99) + response_body = {"completions": [{"data": {"text": "\n\t This is a single response."}}]} + expected_responses = ["This is a single response."] + assert adapter.get_responses(response_body) == expected_responses + + def test_get_responses_multiple_responses(self) -> None: + adapter = AI21LabsJurassic2Adapter(model_kwargs={}, max_length=99) + response_body = { + "completions": [ + {"data": {"text": "This is a single response."}}, + {"data": {"text": "This is a second response."}}, + ] + } + expected_responses = ["This is a single response.", "This is a second response."] + assert adapter.get_responses(response_body) == expected_responses + + +class TestAmazonTitanAdapter: + def test_prepare_body_with_default_params(self) -> None: + layer = AmazonTitanAdapter(model_kwargs={}, max_length=99) + prompt = "Hello, how are you?" + expected_body = {"inputText": "Hello, how are you?", "textGenerationConfig": {"maxTokenCount": 99}} + + body = layer.prepare_body(prompt) + + assert body == expected_body + + def test_prepare_body_with_custom_inference_params(self) -> None: + layer = AmazonTitanAdapter(model_kwargs={}, max_length=99) + prompt = "Hello, how are you?" + expected_body = { + "inputText": "Hello, how are you?", + "textGenerationConfig": { + "maxTokenCount": 50, + "stopSequences": ["CUSTOM_STOP"], + "temperature": 0.7, + "topP": 0.8, + }, + } + + body = layer.prepare_body( + prompt, + maxTokenCount=50, + stopSequences=["CUSTOM_STOP"], + temperature=0.7, + topP=0.8, + unknown_arg="unknown_value", + ) + + assert body == expected_body + + def test_prepare_body_with_model_kwargs(self) -> None: + layer = AmazonTitanAdapter( + model_kwargs={ + "maxTokenCount": 50, + "stopSequences": ["CUSTOM_STOP"], + "temperature": 0.7, + "topP": 0.8, + "unknown_arg": "unknown_value", + }, + max_length=99, + ) + prompt = "Hello, how are you?" + expected_body = { + "inputText": "Hello, how are you?", + "textGenerationConfig": { + "maxTokenCount": 50, + "stopSequences": ["CUSTOM_STOP"], + "temperature": 0.7, + "topP": 0.8, + }, + } + + body = layer.prepare_body(prompt) + + assert body == expected_body + + def test_prepare_body_with_model_kwargs_and_custom_inference_params(self) -> None: + layer = AmazonTitanAdapter( + model_kwargs={ + "maxTokenCount": 49, + "stopSequences": ["CUSTOM_STOP_MODEL_KWARGS"], + "temperature": 0.6, + "topP": 0.7, + }, + max_length=99, + ) + prompt = "Hello, how are you?" 
+ expected_body = { + "inputText": "Hello, how are you?", + "textGenerationConfig": { + "maxTokenCount": 50, + "stopSequences": ["CUSTOM_STOP_MODEL_KWARGS"], + "temperature": 0.7, + "topP": 0.8, + }, + } + + body = layer.prepare_body(prompt, temperature=0.7, topP=0.8, maxTokenCount=50) + + assert body == expected_body + + def test_get_responses(self) -> None: + adapter = AmazonTitanAdapter(model_kwargs={}, max_length=99) + response_body = {"results": [{"outputText": "This is a single response."}]} + expected_responses = ["This is a single response."] + assert adapter.get_responses(response_body) == expected_responses + + def test_get_responses_leading_whitespace(self) -> None: + adapter = AmazonTitanAdapter(model_kwargs={}, max_length=99) + response_body = {"results": [{"outputText": "\n\t This is a single response."}]} + expected_responses = ["This is a single response."] + assert adapter.get_responses(response_body) == expected_responses + + def test_get_responses_multiple_responses(self) -> None: + adapter = AmazonTitanAdapter(model_kwargs={}, max_length=99) + response_body = { + "results": [{"outputText": "This is a single response."}, {"outputText": "This is a second response."}] + } + expected_responses = ["This is a single response.", "This is a second response."] + assert adapter.get_responses(response_body) == expected_responses + + def test_get_stream_responses(self) -> None: + stream_mock = MagicMock() + stream_handler_mock = MagicMock() + + stream_mock.__iter__.return_value = [ + {"chunk": {"bytes": b'{"outputText": " This"}'}}, + {"chunk": {"bytes": b'{"outputText": " is"}'}}, + {"chunk": {"bytes": b'{"outputText": " a"}'}}, + {"chunk": {"bytes": b'{"outputText": " single"}'}}, + {"chunk": {"bytes": b'{"outputText": " response."}'}}, + ] + + stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received + + adapter = AmazonTitanAdapter(model_kwargs={}, max_length=99) + expected_responses = ["This is a single response."] + assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses + + stream_handler_mock.assert_has_calls( + [ + call(" This", event_data={"outputText": " This"}), + call(" is", event_data={"outputText": " is"}), + call(" a", event_data={"outputText": " a"}), + call(" single", event_data={"outputText": " single"}), + call(" response.", event_data={"outputText": " response."}), + ] + ) + + def test_get_stream_responses_empty(self) -> None: + stream_mock = MagicMock() + stream_handler_mock = MagicMock() + + stream_mock.__iter__.return_value = [] + + stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received + + adapter = AmazonTitanAdapter(model_kwargs={}, max_length=99) + expected_responses = [""] + assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses + + stream_handler_mock.assert_not_called() + + +class TestMetaLlama2ChatAdapter: + def test_prepare_body_with_default_params(self) -> None: + layer = MetaLlama2ChatAdapter(model_kwargs={}, max_length=99) + prompt = "Hello, how are you?" + expected_body = {"prompt": "Hello, how are you?", "max_gen_len": 99} + + body = layer.prepare_body(prompt) + + assert body == expected_body + + def test_prepare_body_with_custom_inference_params(self) -> None: + layer = MetaLlama2ChatAdapter(model_kwargs={}, max_length=99) + prompt = "Hello, how are you?" 
+ expected_body = {"prompt": "Hello, how are you?", "max_gen_len": 50, "temperature": 0.7, "top_p": 0.8} + + body = layer.prepare_body(prompt, temperature=0.7, top_p=0.8, max_gen_len=50, unknown_arg="unknown_value") + + assert body == expected_body + + def test_prepare_body_with_model_kwargs(self) -> None: + layer = MetaLlama2ChatAdapter( + model_kwargs={"temperature": 0.7, "top_p": 0.8, "max_gen_len": 50, "unknown_arg": "unknown_value"}, + max_length=99, + ) + prompt = "Hello, how are you?" + expected_body = {"prompt": "Hello, how are you?", "max_gen_len": 50, "temperature": 0.7, "top_p": 0.8} + + body = layer.prepare_body(prompt) + + assert body == expected_body + + def test_prepare_body_with_model_kwargs_and_custom_inference_params(self) -> None: + layer = MetaLlama2ChatAdapter( + model_kwargs={"temperature": 0.6, "top_p": 0.7, "top_k": 4, "max_gen_len": 49}, max_length=99 + ) + prompt = "Hello, how are you?" + expected_body = {"prompt": "Hello, how are you?", "max_gen_len": 50, "temperature": 0.7, "top_p": 0.7} + + body = layer.prepare_body(prompt, temperature=0.7, max_gen_len=50) + + assert body == expected_body + + def test_get_responses(self) -> None: + adapter = MetaLlama2ChatAdapter(model_kwargs={}, max_length=99) + response_body = {"generation": "This is a single response."} + expected_responses = ["This is a single response."] + assert adapter.get_responses(response_body) == expected_responses + + def test_get_responses_leading_whitespace(self) -> None: + adapter = MetaLlama2ChatAdapter(model_kwargs={}, max_length=99) + response_body = {"generation": "\n\t This is a single response."} + expected_responses = ["This is a single response."] + assert adapter.get_responses(response_body) == expected_responses + + def test_get_stream_responses(self) -> None: + stream_mock = MagicMock() + stream_handler_mock = MagicMock() + + stream_mock.__iter__.return_value = [ + {"chunk": {"bytes": b'{"generation": " This"}'}}, + {"chunk": {"bytes": b'{"generation": " is"}'}}, + {"chunk": {"bytes": b'{"generation": " a"}'}}, + {"chunk": {"bytes": b'{"generation": " single"}'}}, + {"chunk": {"bytes": b'{"generation": " response."}'}}, + ] + + stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received + + adapter = MetaLlama2ChatAdapter(model_kwargs={}, max_length=99) + expected_responses = ["This is a single response."] + assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses + + stream_handler_mock.assert_has_calls( + [ + call(" This", event_data={"generation": " This"}), + call(" is", event_data={"generation": " is"}), + call(" a", event_data={"generation": " a"}), + call(" single", event_data={"generation": " single"}), + call(" response.", event_data={"generation": " response."}), + ] + ) + + def test_get_stream_responses_empty(self) -> None: + stream_mock = MagicMock() + stream_handler_mock = MagicMock() + + stream_mock.__iter__.return_value = [] + + stream_handler_mock.side_effect = lambda token_received, **kwargs: token_received + + adapter = MetaLlama2ChatAdapter(model_kwargs={}, max_length=99) + expected_responses = [""] + assert adapter.get_stream_responses(stream_mock, stream_handler_mock) == expected_responses + + stream_handler_mock.assert_not_called() diff --git a/test/prompt/invocation_layer/test_invocation_layers.py b/test/prompt/invocation_layer/test_invocation_layers.py index ccf17b697..e2a36998f 100644 --- a/test/prompt/invocation_layer/test_invocation_layers.py +++ b/test/prompt/invocation_layer/test_invocation_layers.py 
@@ -15,5 +15,5 @@ def test_invocation_layer_order(): assert HFInferenceEndpointInvocationLayer in invocation_layers index_hf = invocation_layers.index(HFLocalInvocationLayer) + 1 index_hf_inference = invocation_layers.index(HFInferenceEndpointInvocationLayer) + 1 - assert index_hf > len(invocation_layers) / 2 - assert index_hf_inference > len(invocation_layers) / 2 + assert index_hf >= 7 + assert index_hf_inference >= 7 diff --git a/test/prompt/invocation_layer/test_sagemaker_hf_infer.py b/test/prompt/invocation_layer/test_sagemaker_hf_infer.py index 2d2d66515..b2b803212 100644 --- a/test/prompt/invocation_layer/test_sagemaker_hf_infer.py +++ b/test/prompt/invocation_layer/test_sagemaker_hf_infer.py @@ -223,7 +223,7 @@ def test_supports_for_valid_aws_configuration(): # Patch the class method to return the mock session with patch( - "haystack.nodes.prompt.invocation_layer.sagemaker_base.SageMakerBaseInvocationLayer.create_session", + "haystack.nodes.prompt.invocation_layer.aws_base.AWSBaseInvocationLayer.get_aws_session", return_value=mock_session, ): supported = SageMakerHFInferenceInvocationLayer.supports( @@ -245,12 +245,10 @@ def test_supports_not_on_invalid_aws_profile_name(): with patch("boto3.Session") as mock_boto3_session: mock_boto3_session.side_effect = BotoCoreError() - with pytest.raises(SageMakerConfigurationError) as exc_info: - supported = SageMakerHFInferenceInvocationLayer.supports( + with pytest.raises(SageMakerConfigurationError, match="Failed to initialize the session"): + SageMakerHFInferenceInvocationLayer.supports( model_name_or_path="some_fake_model", aws_profile_name="some_fake_profile" ) - assert "Failed to initialize the session" in exc_info.value - assert not supported @pytest.mark.skipif( diff --git a/test/prompt/invocation_layer/test_sagemaker_hf_text_gen.py b/test/prompt/invocation_layer/test_sagemaker_hf_text_gen.py index b37d58009..13b31b81d 100644 --- a/test/prompt/invocation_layer/test_sagemaker_hf_text_gen.py +++ b/test/prompt/invocation_layer/test_sagemaker_hf_text_gen.py @@ -81,9 +81,8 @@ def test_invoke_with_no_kwargs(mock_auto_tokenizer, mock_boto3_session): Test that invoke raises an error if no prompt is provided """ layer = SageMakerHFTextGenerationInvocationLayer(model_name_or_path="some_fake_model") - with pytest.raises(ValueError) as e: + with pytest.raises(ValueError, match="No prompt provided."): layer.invoke() - assert e.match("No prompt provided.") @pytest.mark.unit @@ -223,7 +222,7 @@ def test_supports_for_valid_aws_configuration(): # Patch the class method to return the mock session with patch( - "haystack.nodes.prompt.invocation_layer.sagemaker_base.SageMakerBaseInvocationLayer.create_session", + "haystack.nodes.prompt.invocation_layer.aws_base.AWSBaseInvocationLayer.get_aws_session", return_value=mock_session, ): supported = SageMakerHFTextGenerationInvocationLayer.supports( @@ -245,12 +244,10 @@ def test_supports_not_on_invalid_aws_profile_name(): with patch("boto3.Session") as mock_boto3_session: mock_boto3_session.side_effect = BotoCoreError() - with pytest.raises(SageMakerConfigurationError) as exc_info: - supported = SageMakerHFTextGenerationInvocationLayer.supports( + with pytest.raises(SageMakerConfigurationError, match="Failed to initialize the session"): + SageMakerHFTextGenerationInvocationLayer.supports( model_name_or_path="some_fake_model", aws_profile_name="some_fake_profile" ) - assert "Failed to initialize the session" in exc_info.value - assert not supported @pytest.mark.skipif( diff --git 
a/test/prompt/invocation_layer/test_sagemaker_meta.py b/test/prompt/invocation_layer/test_sagemaker_meta.py index 5081a19af..6221a07e7 100644 --- a/test/prompt/invocation_layer/test_sagemaker_meta.py +++ b/test/prompt/invocation_layer/test_sagemaker_meta.py @@ -97,9 +97,8 @@ def test_invoke_with_no_kwargs(mock_auto_tokenizer, mock_boto3_session): Test invoke raises an error if no prompt is provided """ layer = SageMakerMetaInvocationLayer(model_name_or_path="some_fake_model") - with pytest.raises(ValueError) as e: + with pytest.raises(ValueError, match="No valid prompt provided."): layer.invoke() - assert e.match("No prompt provided.") @pytest.mark.unit @@ -239,7 +238,7 @@ def test_supports_for_valid_aws_configuration(): # Patch the class method to return the mock session with patch( - "haystack.nodes.prompt.invocation_layer.sagemaker_base.SageMakerBaseInvocationLayer.create_session", + "haystack.nodes.prompt.invocation_layer.aws_base.AWSBaseInvocationLayer.get_aws_session", return_value=mock_session, ): supported = SageMakerMetaInvocationLayer.supports( @@ -263,14 +262,12 @@ def test_supports_not_on_invalid_aws_profile_name(): with patch("boto3.Session") as mock_boto3_session: mock_boto3_session.side_effect = BotoCoreError() - with pytest.raises(SageMakerConfigurationError) as exc_info: - supported = SageMakerMetaInvocationLayer.supports( + with pytest.raises(SageMakerConfigurationError, match="Failed to initialize the session"): + SageMakerMetaInvocationLayer.supports( model_name_or_path="some_fake_model", aws_profile_name="some_fake_profile", aws_custom_attributes={"accept_eula": True}, ) - assert "Failed to initialize the session" in exc_info.value - assert not supported @pytest.mark.unit @@ -287,7 +284,7 @@ def test_supports_not_on_missing_eula(): # Patch the class method to return the mock session with patch( - "haystack.nodes.prompt.invocation_layer.sagemaker_base.SageMakerBaseInvocationLayer.create_session", + "haystack.nodes.prompt.invocation_layer.aws_base.AWSBaseInvocationLayer.get_aws_session", return_value=mock_session, ): supported = SageMakerMetaInvocationLayer.supports( @@ -311,7 +308,7 @@ def test_supports_not_on_eula_not_accepted(): # Patch the class method to return the mock session with patch( - "haystack.nodes.prompt.invocation_layer.sagemaker_base.SageMakerBaseInvocationLayer.create_session", + "haystack.nodes.prompt.invocation_layer.aws_base.AWSBaseInvocationLayer.get_aws_session", return_value=mock_session, ): supported = SageMakerMetaInvocationLayer.supports(