feat: Enable Support for Meta Llama-2 Models in Amazon SageMaker (#5437)

* Enable Support for Meta Llama-2 Models in Amazon SageMaker

* Improve unit test for invocation layer positioning

* Small adjustment, add more unit tests

* mypy fixes

* Improve unit tests

* Update test/prompt/invocation_layer/test_sagemaker_meta.py

Co-authored-by: Stefano Fiorucci <44616784+anakin87@users.noreply.github.com>

* PR feedback

* Add pydocs for newly extracted methods

* simplify is_proper_chat_*

---------

Co-authored-by: Stefano Fiorucci <44616784+anakin87@users.noreply.github.com>
Co-authored-by: anakin87 <stefanofiorucci@gmail.com>
Vladimir Blagojevic 2023-07-26 15:26:39 +02:00 committed by GitHub
parent 9ab6298f1d
commit 409e3471cb
5 changed files with 834 additions and 6 deletions


@@ -8,5 +8,6 @@ from haystack.nodes.prompt.invocation_layer.anthropic_claude import AnthropicCla
from haystack.nodes.prompt.invocation_layer.cohere import CohereInvocationLayer
from haystack.nodes.prompt.invocation_layer.hugging_face import HFLocalInvocationLayer
from haystack.nodes.prompt.invocation_layer.hugging_face_inference import HFInferenceEndpointInvocationLayer
from haystack.nodes.prompt.invocation_layer.sagemaker_meta import SageMakerMetaInvocationLayer
from haystack.nodes.prompt.invocation_layer.sagemaker_hf_infer import SageMakerHFInferenceInvocationLayer
from haystack.nodes.prompt.invocation_layer.sagemaker_hf_text_gen import SageMakerHFTextGenerationInvocationLayer
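With the new export in place, the Meta layer can also be instantiated directly instead of relying on PromptNode's automatic provider selection. Below is a minimal, hedged sketch; the endpoint name and profile name are placeholders, and a real, in-service SageMaker Llama-2 endpoint is required.

```python
# Hedged sketch, not part of this commit: using the newly exported layer directly.
# "llama-2-7b" and "my_aws_profile_name" are placeholders for your own endpoint/profile.
from haystack.nodes.prompt.invocation_layer import SageMakerMetaInvocationLayer

layer = SageMakerMetaInvocationLayer(
    model_name_or_path="llama-2-7b",              # SageMaker endpoint name (placeholder)
    max_length=64,
    aws_profile_name="my_aws_profile_name",       # placeholder AWS profile
    aws_custom_attributes={"accept_eula": True},  # Llama-2 endpoints require accepting the EULA
)
print(layer.invoke(prompt="Berlin is the capital of"))
```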


@@ -87,7 +87,7 @@ class SageMakerBaseInvocationLayer(PromptModelInvocationLayer, ABC):
test_payload = cls.get_test_payload()
# send test payload to endpoint to see if it's supported
supported = cls.check_model_input_format(session, model_name_or_path, test_payload)
supported = cls.check_model_input_format(session, model_name_or_path, test_payload, **kwargs)
return supported
return False
@@ -137,7 +137,18 @@ class SageMakerBaseInvocationLayer(PromptModelInvocationLayer, ABC):
client.close()
@classmethod
def check_model_input_format(cls, session: "boto3.Session", endpoint: str, test_payload: Dict[str, str]):
def format_custom_attributes(cls, attributes: dict) -> str:
"""
Formats the custom attributes for the SageMaker endpoint.
:param attributes: The custom attributes to format.
:return: The formatted custom attributes.
"""
if attributes:
return ";".join(f"{k}={str(v).lower() if isinstance(v, bool) else str(v)}" for k, v in attributes.items())
return ""
@classmethod
def check_model_input_format(cls, session: "boto3.Session", endpoint: str, test_payload: Any, **kwargs):
"""
Checks if the SageMaker endpoint supports the test_payload model input format.
:param session: The boto3 session.
@@ -146,6 +157,8 @@ class SageMakerBaseInvocationLayer(PromptModelInvocationLayer, ABC):
:return: True if the endpoint supports the test_payload model input format, False otherwise.
"""
boto3_import.check()
custom_attributes = kwargs.get("aws_custom_attributes", None)
custom_attributes = SageMakerBaseInvocationLayer.format_custom_attributes(custom_attributes)
client = None
try:
client = session.client("sagemaker-runtime")
@@ -154,6 +167,7 @@ class SageMakerBaseInvocationLayer(PromptModelInvocationLayer, ABC):
Body=json.dumps(test_payload),
ContentType="application/json",
Accept="application/json",
CustomAttributes=custom_attributes,
)
except ClientError:
# raised if the endpoint doesn't support the test_payload model input format
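For illustration, here is a standalone sketch (separate from the diff) of what the custom-attributes formatting produces. It mirrors the format_custom_attributes helper added above; the expected strings follow the same rules exercised by the unit tests added later in this commit.

```python
# Standalone illustration of the custom-attributes formatting added above:
# booleans are lowercased, other values are stringified, pairs are joined with ";".
def format_custom_attributes(attributes: dict) -> str:
    if attributes:
        return ";".join(
            f"{k}={str(v).lower() if isinstance(v, bool) else str(v)}"
            for k, v in attributes.items()
        )
    return ""

assert format_custom_attributes({"accept_eula": True}) == "accept_eula=true"
assert format_custom_attributes({"key1": 1, "key2": 2}) == "key1=1;key2=2"
assert format_custom_attributes(None) == ""
```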


@@ -0,0 +1,335 @@
import json
import logging
from typing import Optional, Dict, List, Any, Union
import requests
from haystack.errors import SageMakerModelNotReadyError, SageMakerInferenceError, SageMakerConfigurationError
from haystack.lazy_imports import LazyImport
from haystack.nodes.prompt.invocation_layer.sagemaker_base import SageMakerBaseInvocationLayer
logger = logging.getLogger(__name__)
with LazyImport(message="Run 'pip install farm-haystack[aws]'") as boto3_import:
pass
class SageMakerMetaInvocationLayer(SageMakerBaseInvocationLayer):
"""
SageMaker Meta Invocation Layer
SageMakerMetaInvocationLayer enables the use of Meta Large Language Models (LLMs) hosted on a SageMaker
Inference Endpoint via PromptNode. It primarily focuses on Llama-2 models and supports both the chat and
instruction-following variants. Other Meta models have not been tested.
For guidance on how to deploy such a model to SageMaker, refer to
the [SageMaker JumpStart foundation models documentation](https://docs.aws.amazon.com/sagemaker/latest/dg/jumpstart-foundation-models-use.html)
and follow the instructions provided there.
As of July 24, this layer has been confirmed to support the following SageMaker deployed models:
- Llama-2 models
Technical Note:
This layer is designed for models that anticipate an input format composed of the following keys/values:
{'inputs': 'prompt_text', 'parameters': params} where 'inputs' represents the prompt and 'parameters' the
parameters for the model.
**Examples**
```python
from haystack.nodes import PromptNode
# Pass sagemaker endpoint name and authentication details
pn = PromptNode(model_name_or_path="llama-2-7b",
model_kwargs={"aws_profile_name": "my_aws_profile_name"})
res = pn("Berlin is the capital of")
print(res)
```
**Example using AWS env variables**
```python
import os
from haystack.nodes import PromptNode
# We can also configure SageMaker via AWS environment variables without AWS profile name
pn = PromptNode(model_name_or_path="llama-2-7b", max_length=512,
model_kwargs={"aws_access_key_id": os.getenv("AWS_ACCESS_KEY_ID"),
"aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
"aws_session_token": os.getenv("AWS_SESSION_TOKEN"),
"aws_region_name": "us-east-1"})
response = pn("The secret for a good life is")
print(response)
```
Llama-2 also supports the chat format.
**Example using chat format**
```python
from haystack.nodes.prompt import PromptNode
pn = PromptNode(model_name_or_path="llama-2-7b-chat", max_length=512, model_kwargs={"aws_profile_name": "default",
"aws_custom_attributes": {"accept_eula": True}})
pn_input = [[{"role": "user", "content": "what is the recipe of mayonnaise?"}]]
response = pn(pn_input)
print(response)
```
Note that in the chat examples we can also include multiple turns between the user and the assistant. See the
Llama-2 chat documentation for more details.
**Example using chat format with multiple turns**
```python
from haystack.nodes.prompt import PromptNode
pn = PromptNode(model_name_or_path="llama-2-7b-chat", max_length=512, model_kwargs={"aws_profile_name": "default",
"aws_custom_attributes": {"accept_eula": True}})
pn_input = [[
{"role": "user", "content": "I am going to Paris, what should I see?"},
{"role": "assistant", "content": "Paris, the capital of France, is known for its stunning architecture, art museums, historical landmarks, and romantic atmosphere. Here are some of the top attractions to see in Paris:\n
1. The Eiffel Tower: The iconic Eiffel Tower is one of the most recognizable landmarks in the world and offers breathtaking views of the city.\n
2. The Louvre Museum: The Louvre is one of the world's largest and most famous museums, housing an impressive collection of art and artifacts, including the Mona Lisa.\n
3. Notre-Dame Cathedral: This beautiful cathedral is one of the most famous landmarks in Paris and is known for its Gothic architecture and stunning stained glass windows.\n
These are just a few of the many attractions that Paris has to offer. With so much to see and do, it's no wonder that Paris is one of the most popular tourist destinations in the world.",},
{"role": "user", "content": "What is so great about #1?"}]]
response = pn(pn_input)
print(response)
```
Llama-2 models support the following inference payload parameters:
- max_new_tokens: The model generates text until the output length (excluding the input context length) reaches
max_new_tokens. If specified, it must be a positive integer.
- temperature: Controls the randomness in the output. A higher temperature results in an output sequence with
low-probability words; a lower temperature results in an output sequence with high-probability words.
As temperature approaches 0, decoding becomes greedy. If specified, it must be a positive float.
- top_p: In each step of text generation, sample from the smallest possible set of words with cumulative
probability top_p. If specified, it must be a float between 0 and 1.
- return_full_text: If True, the input text is included in the generated output text. If specified, it must be
a boolean. The default value is False.
In both examples, your endpoint names, region names, and other settings will differ; you can find them
in the SageMaker AWS console.
"""
def __init__(
self,
model_name_or_path: str,
max_length: int = 100,
aws_access_key_id: Optional[str] = None,
aws_secret_access_key: Optional[str] = None,
aws_session_token: Optional[str] = None,
aws_region_name: Optional[str] = None,
aws_profile_name: Optional[str] = None,
**kwargs,
):
"""
Instantiates the session with SageMaker using IAM based authentication via boto3.
:param model_name_or_path: The name for SageMaker Model Endpoint.
:param max_length: The maximum length of the output text.
:param aws_access_key_id: AWS access key ID.
:param aws_secret_access_key: AWS secret access key.
:param aws_session_token: AWS session token.
:param aws_region_name: AWS region name.
:param aws_profile_name: AWS profile name.
"""
# We set the default model_max_length to 4096 as this is the context window size supported by Llama-2 models
kwargs.setdefault("model_max_length", 4096)
super().__init__(model_name_or_path, max_length=max_length, **kwargs)
try:
session = SageMakerMetaInvocationLayer.create_session(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token,
aws_region_name=aws_region_name,
aws_profile_name=aws_profile_name,
)
self.client = session.client("sagemaker-runtime")
except Exception as e:
raise SageMakerInferenceError(
f"Could not connect to SageMaker Inference Endpoint {model_name_or_path}."
f"Make sure the Endpoint exists and AWS environment is configured."
) from e
# save the kwargs for the model invocation
self.model_input_kwargs = kwargs
# As of July 24, the SageMaker Meta layer does not support streaming responses.
# However, users may still attempt to request streaming.
# Keep stream and stream_handler so we can warn about this and support streaming in the future.
self.stream_handler = kwargs.get("stream_handler", None)
self.stream = kwargs.get("stream", False)
def invoke(self, *args, **kwargs) -> List[str]:
"""
Sends the prompt to the remote model and returns the generated response(s).
:return: The generated responses from the model as a list of strings.
"""
prompt: Any = kwargs.get("prompt")
if not prompt or not isinstance(prompt, (str, list)):
raise ValueError(
f"No valid prompt provided. Model {self.model_name_or_path} requires a valid prompt."
f"Make sure to provide a prompt in the format that the model expects."
)
if not (isinstance(prompt, str) or self.is_proper_chat_conversation_format(prompt)):
raise ValueError(
f"The prompt format is different than what the model expects. "
f"The model {self.model_name_or_path} requires either a string or messages in the specific chat format. "
f"For more details, see https://github.com/facebookresearch/llama/blob/main/llama/generation.py#L213)."
)
stream = kwargs.get("stream", self.stream)
stream_handler = kwargs.get("stream_handler", self.stream_handler)
streaming_requested = stream or stream_handler is not None
if streaming_requested:
raise SageMakerConfigurationError("SageMaker model response streaming is not supported yet")
kwargs_with_defaults = self.model_input_kwargs
kwargs_with_defaults.update(kwargs)
default_params = {
"max_new_tokens": self.max_length,
"return_full_text": None,
"temperature": None,
"top_p": None,
}
# put the param in params if it's in kwargs and not None (i.e. it is actually defined)
# endpoint doesn't tolerate None values, send only the params that are defined
params = {
param: kwargs_with_defaults.get(param, default)
for param, default in default_params.items()
if param in kwargs_with_defaults or default is not None
}
generated_texts = self._post(prompt=prompt, params=params)
return generated_texts
def _post(self, prompt: Any, params: Optional[Dict[str, Any]] = None) -> List[str]:
"""
Post data to the SageMaker inference model. It takes in a prompt and returns a list of responses using model
invocation.
:param prompt: The prompt text/messages to be sent to the model.
:param params: The parameters to be sent to the Meta model.
:return: The generated responses as a list of strings.
"""
custom_attributes = SageMakerBaseInvocationLayer.format_custom_attributes(
self.model_input_kwargs.get("aws_custom_attributes", {})
)
try:
body = {"inputs": prompt, "parameters": params}
response = self.client.invoke_endpoint(
EndpointName=self.model_name_or_path,
Body=json.dumps(body),
ContentType="application/json",
Accept="application/json",
CustomAttributes=custom_attributes,
)
response_json = response.get("Body").read().decode("utf-8")
output = json.loads(response_json)
generated_texts = [o["generation"] for o in output if "generation" in o]
return generated_texts
except requests.HTTPError as err:
res = err.response
if res.status_code == 429:
raise SageMakerModelNotReadyError(f"Model not ready: {res.text}") from err
raise SageMakerInferenceError(
f"SageMaker Inference returned an error.\nStatus code: {res.status_code}\nResponse body: {res.text}",
status_code=res.status_code,
) from err
def _ensure_token_limit(self, prompt: Union[str, List[Dict[str, str]]]) -> Union[str, List[Dict[str, str]]]:
# the prompt for this model will be of the type str
if isinstance(prompt, str):
return super()._ensure_token_limit(prompt)
else:
# TODO: implement truncation for the chat format
return prompt
def _is_proper_chat_message_format(self, chat_message: Dict[str, str]) -> bool:
"""
Checks whether a chat message is in the proper format.
:param chat_message: The chat message to be checked.
:return: True if the chat message is in the proper format, False otherwise.
"""
allowed_roles = {"user", "assistant", "system"}
return (
isinstance(chat_message, dict)
and "role" in chat_message
and "content" in chat_message
and chat_message["role"] in allowed_roles
)
def is_proper_chat_conversation_format(self, prompt: List[Any]) -> bool:
"""
Checks whether a chat conversation is in the proper format.
:param prompt: The chat conversation to be checked.
:return: True if the chat conversation is in the proper format, False otherwise.
"""
if not isinstance(prompt, list) or len(prompt) == 0:
return False
return all(
isinstance(message_list, list)
and all(self._is_proper_chat_message_format(chat_message) for chat_message in message_list)
for message_list in prompt
)
@classmethod
def get_test_payload(cls) -> Dict[str, Any]:
"""
Return test payload for the model.
"""
# implement the abstract method to fulfill the contract, but it won't be used
# because we override the supports method to check support
# for the chat and instruction following format manually
return {}
@classmethod
def supports(cls, model_name_or_path: str, **kwargs) -> bool:
"""
Checks whether a model_name_or_path passed down (e.g. via PromptNode) is supported by this class.
:param model_name_or_path: The model_name_or_path to check.
"""
aws_configuration_keys = [
"aws_access_key_id",
"aws_secret_access_key",
"aws_session_token",
"aws_region_name",
"aws_profile_name",
]
aws_config_provided = any(key in kwargs for key in aws_configuration_keys)
accept_eula = False
if "aws_custom_attributes" in kwargs and isinstance(kwargs["aws_custom_attributes"], dict):
accept_eula = kwargs["aws_custom_attributes"].get("accept_eula", False)
if aws_config_provided and accept_eula:
boto3_import.check()
# attempt to create a session with the provided credentials
session = cls.check_aws_connect(aws_configuration_keys, kwargs)
# is endpoint in service?
cls.check_endpoint_in_service(session, model_name_or_path)
# let's test both formats as we want to support both chat and instruction following models:
# 1. instruction following format
# send test payload to endpoint to see if it's supported
instruction_test_payload: Dict[str, Any] = {
"inputs": "Hello world",
# don't remove max_new_tokens param, if we don't specify it, the model will generate 4k tokens
"parameters": {"max_new_tokens": 10},
}
supported = cls.check_model_input_format(session, model_name_or_path, instruction_test_payload, **kwargs)
if supported:
return True
# 2. chat format
chat_test_payload: Dict[str, Any] = {
"inputs": [[{"role": "user", "content": "what is the recipe of mayonnaise?"}]],
# don't remove max_new_tokens param, if we don't specify it, the model will generate 4k tokens
"parameters": {"max_new_tokens": 10},
}
supported = cls.check_model_input_format(session, model_name_or_path, chat_test_payload, **kwargs)
return supported
return False
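To make the request/response exchange concrete, here is a small hedged sketch of the body that _post builds and of how the generation field is extracted from the endpoint's answer. The prompt, parameter values, and the sample response below are made up for illustration.

```python
import json

# The layer sends {"inputs": ..., "parameters": ...} as the request body (see _post above).
def build_body(prompt, params):
    return json.dumps({"inputs": prompt, "parameters": params})

body = build_body("Berlin is the capital of", {"max_new_tokens": 10, "top_p": 0.9})

# A Llama-2 endpoint replies with a list of objects carrying a "generation" key;
# the response string below is a made-up example, not real endpoint output.
fake_response = '[{"generation": " Germany."}]'
generated_texts = [o["generation"] for o in json.loads(fake_response) if "generation" in o]
assert generated_texts == [" Germany."]
```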


@@ -7,8 +7,13 @@ from haystack.nodes.prompt.invocation_layer import HFLocalInvocationLayer, HFInf
@pytest.mark.unit
def test_invocation_layer_order():
"""
Checks that the huggingface invocation layer is checked late because it can timeout/be slow to respond.
Checks that the Hugging Face invocation layers are positioned further down the list of providers
because they can time out or be slow to respond.
"""
last_invocation_layers = set(PromptModelInvocationLayer.invocation_layer_providers[-5:])
assert HFLocalInvocationLayer in last_invocation_layers
assert HFInferenceEndpointInvocationLayer in last_invocation_layers
invocation_layers = PromptModelInvocationLayer.invocation_layer_providers
assert HFLocalInvocationLayer in invocation_layers
assert HFInferenceEndpointInvocationLayer in invocation_layers
index_hf = invocation_layers.index(HFLocalInvocationLayer) + 1
index_hf_inference = invocation_layers.index(HFInferenceEndpointInvocationLayer) + 1
assert index_hf > len(invocation_layers) / 2
assert index_hf_inference > len(invocation_layers) / 2
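A hedged sketch of why this ordering matters, assuming PromptModel walks invocation_layer_providers in order and picks the first layer whose supports() returns True; the helper below is an illustration, not Haystack's actual selection code.

```python
# Illustrative only: pick the first registered layer that claims support for the model.
# Assumes PromptModelInvocationLayer is importable from this package, as the test above implies.
from haystack.nodes.prompt.invocation_layer import PromptModelInvocationLayer

def pick_invocation_layer(model_name_or_path: str, **kwargs):
    for layer_cls in PromptModelInvocationLayer.invocation_layer_providers:
        if layer_cls.supports(model_name_or_path, **kwargs):
            return layer_cls
    return None
```

Because supports() may probe remote services and can be slow, the providers that are expensive to check (such as the Hugging Face layers) are kept in the second half of the list.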


@@ -0,0 +1,473 @@
import os
from unittest.mock import patch, MagicMock, Mock
import pytest
from haystack.lazy_imports import LazyImport
from haystack.errors import SageMakerConfigurationError
from haystack.nodes.prompt.invocation_layer import SageMakerMetaInvocationLayer
with LazyImport() as boto3_import:
from botocore.exceptions import BotoCoreError
# create a fixture with mocked boto3 client and session
@pytest.fixture
def mock_boto3_session():
with patch("boto3.Session") as mock_client:
yield mock_client
@pytest.fixture
def mock_prompt_handler():
with patch("haystack.nodes.prompt.invocation_layer.handlers.DefaultPromptHandler") as mock_prompt_handler:
yield mock_prompt_handler
@pytest.mark.unit
def test_default_constructor(mock_auto_tokenizer, mock_boto3_session):
"""
Test that the default constructor sets the correct values
"""
layer = SageMakerMetaInvocationLayer(
model_name_or_path="some_fake_model",
max_length=99,
aws_access_key_id="some_fake_id",
aws_secret_access_key="some_fake_key",
aws_session_token="some_fake_token",
aws_profile_name="some_fake_profile",
aws_region_name="fake_region",
)
assert layer.max_length == 99
assert layer.model_name_or_path == "some_fake_model"
assert layer.prompt_handler is not None
assert layer.prompt_handler.model_max_length == 4096
# assert mocked boto3 client called exactly once
mock_boto3_session.assert_called_once()
# assert mocked boto3 client was called with the correct parameters
mock_boto3_session.assert_called_with(
aws_access_key_id="some_fake_id",
aws_secret_access_key="some_fake_key",
aws_session_token="some_fake_token",
profile_name="some_fake_profile",
region_name="fake_region",
)
@pytest.mark.unit
def test_constructor_prompt_handler_initialized(mock_auto_tokenizer, mock_boto3_session, mock_prompt_handler):
"""
Test that the constructor sets the prompt_handler correctly, with the correct model_max_length for llama-2
"""
layer = SageMakerMetaInvocationLayer(model_name_or_path="some_fake_model", prompt_handler=mock_prompt_handler)
assert layer.prompt_handler is not None
assert layer.prompt_handler.model_max_length == 4096
@pytest.mark.unit
def test_constructor_with_model_kwargs(mock_auto_tokenizer, mock_boto3_session):
"""
Test that model_kwargs are correctly set in the constructor
"""
model_kwargs = {"temperature": 0.7}
layer = SageMakerMetaInvocationLayer(model_name_or_path="some_fake_model", **model_kwargs)
assert "temperature" in layer.model_input_kwargs
assert layer.model_input_kwargs["temperature"] == 0.7
@pytest.mark.unit
def test_constructor_with_empty_model_name():
"""
Test that the constructor raises an error when the model_name_or_path is empty
"""
with pytest.raises(ValueError, match="cannot be None or empty string"):
SageMakerMetaInvocationLayer(model_name_or_path="")
@pytest.mark.unit
def test_invoke_with_no_kwargs(mock_auto_tokenizer, mock_boto3_session):
"""
Test invoke raises an error if no prompt is provided
"""
layer = SageMakerMetaInvocationLayer(model_name_or_path="some_fake_model")
with pytest.raises(ValueError) as e:
layer.invoke()
assert e.match("No valid prompt provided")
@pytest.mark.unit
def test_invoke_with_stop_words(mock_auto_tokenizer, mock_boto3_session):
"""
SageMakerMetaInvocationLayer does not support stop words. Tests that they'll be ignored
"""
stop_words = ["but", "not", "bye"]
layer = SageMakerMetaInvocationLayer(model_name_or_path="some_model")
with patch("haystack.nodes.prompt.invocation_layer.SageMakerMetaInvocationLayer._post") as mock_post:
# Mock the response, need to return a list of dicts
mock_post.return_value = MagicMock(text='[{"generated_text": "Hello"}]')
layer.invoke(prompt="Tell me hello", stop_words=stop_words)
assert mock_post.called
_, call_kwargs = mock_post.call_args
assert "stop_words" not in call_kwargs["params"]
@pytest.mark.unit
def test_short_prompt_is_not_truncated(mock_boto3_session):
"""
Test that a short prompt is not truncated
"""
# Define a short mock prompt and its tokenized version
mock_prompt_text = "I am a tokenized prompt"
mock_prompt_tokens = mock_prompt_text.split()
# Mock the tokenizer so it returns our predefined tokens
mock_tokenizer = MagicMock()
mock_tokenizer.tokenize.return_value = mock_prompt_tokens
# We set a small max_length for generated text (3 tokens) and a total model_max_length of 10 tokens
# Since our mock prompt is 5 tokens long, it doesn't exceed the
# total limit (5 prompt tokens + 3 generated tokens < 10 tokens)
max_length_generated_text = 3
total_model_max_length = 10
with patch("transformers.AutoTokenizer.from_pretrained", return_value=mock_tokenizer):
layer = SageMakerMetaInvocationLayer(
"some_fake_endpoint", max_length=max_length_generated_text, model_max_length=total_model_max_length
)
prompt_after_resize = layer._ensure_token_limit(mock_prompt_text)
# The prompt doesn't exceed the limit, _ensure_token_limit doesn't truncate it
assert prompt_after_resize == mock_prompt_text
@pytest.mark.unit
def test_long_prompt_is_truncated(mock_boto3_session):
"""
Test that a long prompt is truncated
"""
# Define a long mock prompt and its tokenized version
long_prompt_text = "I am a tokenized prompt of length eight"
long_prompt_tokens = long_prompt_text.split()
# _ensure_token_limit will truncate the prompt to make it fit into the model's max token limit
truncated_prompt_text = "I am a tokenized prompt of length"
# Mock the tokenizer to return our predefined tokens
# convert tokens to our predefined truncated text
mock_tokenizer = MagicMock()
mock_tokenizer.tokenize.return_value = long_prompt_tokens
mock_tokenizer.convert_tokens_to_string.return_value = truncated_prompt_text
# We set a small max_length for generated text (3 tokens) and a total model_max_length of 10 tokens
# Our mock prompt is 8 tokens long, so it exceeds the total limit (8 prompt tokens + 3 generated tokens > 10 tokens)
max_length_generated_text = 3
total_model_max_length = 10
with patch("transformers.AutoTokenizer.from_pretrained", return_value=mock_tokenizer):
layer = SageMakerMetaInvocationLayer(
"some_fake_endpoint", max_length=max_length_generated_text, model_max_length=total_model_max_length
)
prompt_after_resize = layer._ensure_token_limit(long_prompt_text)
# The prompt exceeds the limit, _ensure_token_limit truncates it
assert prompt_after_resize == truncated_prompt_text
@pytest.mark.unit
def test_streaming_init_kwarg(mock_auto_tokenizer, mock_boto3_session):
"""
Test stream parameter passed as init kwarg raises an error on layer invocation
"""
layer = SageMakerMetaInvocationLayer(model_name_or_path="irrelevant", stream=True)
with pytest.raises(SageMakerConfigurationError, match="SageMaker model response streaming is not supported yet"):
layer.invoke(prompt="Tell me hello")
@pytest.mark.unit
def test_streaming_invoke_kwarg(mock_auto_tokenizer, mock_boto3_session):
"""
Test stream parameter passed as invoke kwarg raises an error on layer invocation
"""
layer = SageMakerMetaInvocationLayer(model_name_or_path="irrelevant")
with pytest.raises(SageMakerConfigurationError, match="SageMaker model response streaming is not supported yet"):
layer.invoke(prompt="Tell me hello", stream=True)
@pytest.mark.unit
def test_streaming_handler_init_kwarg(mock_auto_tokenizer, mock_boto3_session):
"""
Test stream_handler parameter passed as init kwarg raises an error on layer invocation
"""
layer = SageMakerMetaInvocationLayer(model_name_or_path="irrelevant", stream_handler=Mock())
with pytest.raises(SageMakerConfigurationError, match="SageMaker model response streaming is not supported yet"):
layer.invoke(prompt="Tell me hello")
@pytest.mark.unit
def test_streaming_handler_invoke_kwarg(mock_auto_tokenizer, mock_boto3_session):
"""
Test stream_handler parameter passed as invoke kwarg raises an error on layer invocation
"""
layer = SageMakerMetaInvocationLayer(model_name_or_path="irrelevant")
with pytest.raises(SageMakerConfigurationError, match="SageMaker model response streaming is not supported yet"):
layer.invoke(prompt="Tell me hello", stream_handler=Mock())
@pytest.mark.unit
def test_supports_for_valid_aws_configuration():
"""
Test that the SageMakerMetaInvocationLayer identifies a valid SageMaker Inference endpoint via the supports() method
"""
mock_client = MagicMock()
mock_client.describe_endpoint.return_value = {"EndpointStatus": "InService"}
mock_session = MagicMock()
mock_session.client.return_value = mock_client
# Patch the class method to return the mock session
with patch(
"haystack.nodes.prompt.invocation_layer.sagemaker_base.SageMakerBaseInvocationLayer.create_session",
return_value=mock_session,
):
supported = SageMakerMetaInvocationLayer.supports(
model_name_or_path="some_sagemaker_deployed_model",
aws_profile_name="some_real_profile",
aws_custom_attributes={"accept_eula": True},
)
args, kwargs = mock_client.describe_endpoint.call_args
assert kwargs["EndpointName"] == "some_sagemaker_deployed_model"
args, kwargs = mock_session.client.call_args
assert args[0] == "sagemaker-runtime"
assert supported
@pytest.mark.unit
def test_supports_not_on_invalid_aws_profile_name():
"""
Test that the SageMakerMetaInvocationLayer raises SageMakerConfigurationError when the profile name is invalid
"""
with patch("boto3.Session") as mock_boto3_session:
mock_boto3_session.side_effect = BotoCoreError()
with pytest.raises(SageMakerConfigurationError) as exc_info:
supported = SageMakerMetaInvocationLayer.supports(
model_name_or_path="some_fake_model",
aws_profile_name="some_fake_profile",
aws_custom_attributes={"accept_eula": True},
)
assert "Failed to initialize the session" in exc_info.value
assert not supported
@pytest.mark.unit
def test_supports_not_on_missing_eula():
"""
Test that the SageMakerMetaInvocationLayer is not supported when the EULA is missing
"""
mock_client = MagicMock()
mock_client.describe_endpoint.return_value = {"EndpointStatus": "InService"}
mock_session = MagicMock()
mock_session.client.return_value = mock_client
# Patch the class method to return the mock session
with patch(
"haystack.nodes.prompt.invocation_layer.sagemaker_base.SageMakerBaseInvocationLayer.create_session",
return_value=mock_session,
):
supported = SageMakerMetaInvocationLayer.supports(
model_name_or_path="some_sagemaker_deployed_model", aws_profile_name="some_real_profile"
)
assert not supported
@pytest.mark.unit
def test_supports_not_on_eula_not_accepted():
"""
Test that the SageMakerMetaInvocationLayer is not supported when the EULA is not accepted
"""
mock_client = MagicMock()
mock_client.describe_endpoint.return_value = {"EndpointStatus": "InService"}
mock_session = MagicMock()
mock_session.client.return_value = mock_client
# Patch the class method to return the mock session
with patch(
"haystack.nodes.prompt.invocation_layer.sagemaker_base.SageMakerBaseInvocationLayer.create_session",
return_value=mock_session,
):
supported = SageMakerMetaInvocationLayer.supports(
model_name_or_path="some_sagemaker_deployed_model",
aws_profile_name="some_real_profile",
aws_custom_attributes={"accept_eula": False},
)
assert not supported
@pytest.mark.unit
def test_format_custom_attributes_with_non_empty_dict():
"""
Test that the SageMakerMetaInvocationLayer correctly formats the custom attributes, attributes specified
"""
attributes = {"key1": "value1", "key2": "value2"}
expected_output = "key1=value1;key2=value2"
assert SageMakerMetaInvocationLayer.format_custom_attributes(attributes) == expected_output
@pytest.mark.unit
def test_format_custom_attributes_with_empty_dict():
"""
Test that the SageMakerMetaInvocationLayer correctly formats the custom attributes, attributes not specified
"""
attributes = {}
expected_output = ""
assert SageMakerMetaInvocationLayer.format_custom_attributes(attributes) == expected_output
@pytest.mark.unit
def test_format_custom_attributes_with_none():
"""
Test that the SageMakerMetaInvocationLayer correctly formats the custom attributes, attributes are None
"""
attributes = None
expected_output = ""
assert SageMakerMetaInvocationLayer.format_custom_attributes(attributes) == expected_output
@pytest.mark.unit
def test_format_custom_attributes_with_bool_value():
"""
Test that the SageMakerMetaInvocationLayer correctly formats the custom attributes, attributes are bool
"""
attributes = {"key1": True, "key2": False}
expected_output = "key1=true;key2=false"
assert SageMakerMetaInvocationLayer.format_custom_attributes(attributes) == expected_output
@pytest.mark.unit
def test_format_custom_attributes_with_single_bool_value():
"""
Test that the SageMakerMetaInvocationLayer correctly formats the custom attributes, attributes are single bool
"""
attributes = {"key1": True}
expected_output = "key1=true"
assert SageMakerMetaInvocationLayer.format_custom_attributes(attributes) == expected_output
@pytest.mark.unit
def test_format_custom_attributes_with_int_value():
"""
Test that the SageMakerMetaInvocationLayer correctly formats the custom attributes, attributes are ints
"""
attributes = {"key1": 1, "key2": 2}
expected_output = "key1=1;key2=2"
assert SageMakerMetaInvocationLayer.format_custom_attributes(attributes) == expected_output
@pytest.mark.unit
def test_invoke_chat_format(mock_auto_tokenizer, mock_boto3_session):
"""
Test that the SageMakerMetaInvocationLayer accepts a chat in the correct format
"""
# test the format of the chat, no exception should be raised
layer = SageMakerMetaInvocationLayer(model_name_or_path="some_fake_model")
prompt = [[{"role": "user", "content": "Hello"}]]
expected_response = [[{"role": "user", "content": "Hello"}, {"role": "assistant", "content": "Hello there"}]]
with patch("haystack.nodes.prompt.invocation_layer.sagemaker_meta.SageMakerMetaInvocationLayer._post") as mock_post:
mock_post.return_value = expected_response
layer.invoke(prompt=prompt)
@pytest.mark.unit
def test_invoke_invalid_chat_format(mock_auto_tokenizer, mock_boto3_session):
"""
Test that the SageMakerMetaInvocationLayer raises an exception when the chat is in the wrong format
"""
# test the invalid format of the chat, should raise an exception
layer = SageMakerMetaInvocationLayer(model_name_or_path="some_fake_model")
prompt = [{"roe": "user", "cotent": "Hello"}]
expected_response = [[{"role": "user", "content": "Hello"}, {"role": "assistant", "content": "Hello there"}]]
with patch("haystack.nodes.prompt.invocation_layer.sagemaker_meta.SageMakerMetaInvocationLayer._post") as mock_post:
mock_post.return_value = expected_response
with pytest.raises(ValueError, match="The prompt format is different than what the model expects"):
layer.invoke(prompt=prompt)
@pytest.mark.unit
def test_invoke_prompt_string(mock_auto_tokenizer, mock_boto3_session):
"""
Test that the SageMakerMetaInvocationLayer accepts a prompt in the correct string format
"""
# test the format of the prompt instruction, no exception should be raised
layer = SageMakerMetaInvocationLayer(model_name_or_path="some_fake_model")
with patch("haystack.nodes.prompt.invocation_layer.sagemaker_meta.SageMakerMetaInvocationLayer._post") as mock_post:
mock_post.return_value = ["Hello there"]
layer.invoke(prompt="Hello")
@pytest.mark.unit
def test_invoke_empty_prompt(mock_auto_tokenizer, mock_boto3_session):
"""
Test that the SageMakerMetaInvocationLayer raises an exception when the prompt is empty string
"""
layer = SageMakerMetaInvocationLayer(model_name_or_path="some_fake_model")
with pytest.raises(ValueError):
layer.invoke(prompt="")
@pytest.mark.unit
def test_invoke_improper_prompt_type(mock_auto_tokenizer, mock_boto3_session):
"""
Test that the SageMakerMetaInvocationLayer raises an exception when the prompt is int instead of str
"""
layer = SageMakerMetaInvocationLayer(model_name_or_path="some_fake_model")
prompt = 123
with pytest.raises(ValueError):
layer.invoke(prompt=prompt)
@pytest.mark.skipif(
not os.environ.get("TEST_SAGEMAKER_MODEL_ENDPOINT", None), reason="Skipping because SageMaker not configured"
)
@pytest.mark.integration
def test_supports_triggered_for_valid_sagemaker_endpoint():
"""
Test that the SageMakerMetaInvocationLayer identifies a valid SageMaker Inference endpoint via the supports() method
"""
model_name_or_path = os.environ.get("TEST_SAGEMAKER_MODEL_ENDPOINT")
assert SageMakerMetaInvocationLayer.supports(model_name_or_path=model_name_or_path)
@pytest.mark.skipif(
not os.environ.get("TEST_SAGEMAKER_MODEL_ENDPOINT", None), reason="Skipping because SageMaker not configured"
)
@pytest.mark.integration
def test_supports_not_triggered_for_invalid_iam_profile():
"""
Test that the SageMakerMetaInvocationLayer identifies an invalid SageMaker Inference endpoint
(in this case, due to an invalid AWS IAM profile) via the supports() method
"""
assert not SageMakerMetaInvocationLayer.supports(model_name_or_path="fake_endpoint")
assert not SageMakerMetaInvocationLayer.supports(
model_name_or_path="fake_endpoint", aws_profile_name="invalid-profile"
)