mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-01 02:09:39 +00:00
feat: Enable Support for Meta LLama-2 Models in Amazon Sagemaker (#5437)
* Enable Support for Meta LLama-2 Models in Amazon Sagemaker * Improve unit test for invocation layers positioning * Small adjustment, add more unit tests * mypy fixes * Improve unit tests * Update test/prompt/invocation_layer/test_sagemaker_meta.py Co-authored-by: Stefano Fiorucci <44616784+anakin87@users.noreply.github.com> * PR feedback * Add pydocs for newly extracted methods * simplify is_proper_chat_* --------- Co-authored-by: Stefano Fiorucci <44616784+anakin87@users.noreply.github.com> Co-authored-by: anakin87 <stefanofiorucci@gmail.com>
This commit is contained in:
parent
9ab6298f1d
commit
409e3471cb
@ -8,5 +8,6 @@ from haystack.nodes.prompt.invocation_layer.anthropic_claude import AnthropicCla
|
||||
from haystack.nodes.prompt.invocation_layer.cohere import CohereInvocationLayer
|
||||
from haystack.nodes.prompt.invocation_layer.hugging_face import HFLocalInvocationLayer
|
||||
from haystack.nodes.prompt.invocation_layer.hugging_face_inference import HFInferenceEndpointInvocationLayer
|
||||
from haystack.nodes.prompt.invocation_layer.sagemaker_meta import SageMakerMetaInvocationLayer
|
||||
from haystack.nodes.prompt.invocation_layer.sagemaker_hf_infer import SageMakerHFInferenceInvocationLayer
|
||||
from haystack.nodes.prompt.invocation_layer.sagemaker_hf_text_gen import SageMakerHFTextGenerationInvocationLayer
|
||||
|
||||
@ -87,7 +87,7 @@ class SageMakerBaseInvocationLayer(PromptModelInvocationLayer, ABC):
|
||||
|
||||
test_payload = cls.get_test_payload()
|
||||
# send test payload to endpoint to see if it's supported
|
||||
supported = cls.check_model_input_format(session, model_name_or_path, test_payload)
|
||||
supported = cls.check_model_input_format(session, model_name_or_path, test_payload, **kwargs)
|
||||
return supported
|
||||
return False
|
||||
|
||||
@ -137,7 +137,18 @@ class SageMakerBaseInvocationLayer(PromptModelInvocationLayer, ABC):
|
||||
client.close()
|
||||
|
||||
@classmethod
|
||||
def check_model_input_format(cls, session: "boto3.Session", endpoint: str, test_payload: Dict[str, str]):
|
||||
def format_custom_attributes(cls, attributes: dict) -> str:
|
||||
"""
|
||||
Formats the custom attributes for the SageMaker endpoint.
|
||||
:param attributes: The custom attributes to format.
|
||||
:return: The formatted custom attributes.
|
||||
"""
|
||||
if attributes:
|
||||
return ";".join(f"{k}={str(v).lower() if isinstance(v, bool) else str(v)}" for k, v in attributes.items())
|
||||
return ""
|
||||
|
||||
@classmethod
|
||||
def check_model_input_format(cls, session: "boto3.Session", endpoint: str, test_payload: Any, **kwargs):
|
||||
"""
|
||||
Checks if the SageMaker endpoint supports the test_payload model input format.
|
||||
:param session: The boto3 session.
|
||||
@ -146,6 +157,8 @@ class SageMakerBaseInvocationLayer(PromptModelInvocationLayer, ABC):
|
||||
:return: True if the endpoint supports the test_payload model input format, False otherwise.
|
||||
"""
|
||||
boto3_import.check()
|
||||
custom_attributes = kwargs.get("aws_custom_attributes", None)
|
||||
custom_attributes = SageMakerBaseInvocationLayer.format_custom_attributes(custom_attributes)
|
||||
client = None
|
||||
try:
|
||||
client = session.client("sagemaker-runtime")
|
||||
@ -154,6 +167,7 @@ class SageMakerBaseInvocationLayer(PromptModelInvocationLayer, ABC):
|
||||
Body=json.dumps(test_payload),
|
||||
ContentType="application/json",
|
||||
Accept="application/json",
|
||||
CustomAttributes=custom_attributes,
|
||||
)
|
||||
except ClientError:
|
||||
# raised if the endpoint doesn't support the test_payload model input format
|
||||
|
||||
335
haystack/nodes/prompt/invocation_layer/sagemaker_meta.py
Normal file
335
haystack/nodes/prompt/invocation_layer/sagemaker_meta.py
Normal file
@ -0,0 +1,335 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, Dict, List, Any, Union
|
||||
|
||||
import requests
|
||||
|
||||
from haystack.errors import SageMakerModelNotReadyError, SageMakerInferenceError, SageMakerConfigurationError
|
||||
from haystack.lazy_imports import LazyImport
|
||||
from haystack.nodes.prompt.invocation_layer.sagemaker_base import SageMakerBaseInvocationLayer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
with LazyImport(message="Run 'pip install farm-haystack[aws]'") as boto3_import:
|
||||
pass
|
||||
|
||||
|
||||
class SageMakerMetaInvocationLayer(SageMakerBaseInvocationLayer):
|
||||
"""
|
||||
SageMaker Meta Invocation Layer
|
||||
|
||||
|
||||
SageMakerMetaInvocationLayer enables the use of Meta Large Language Models (LLMs) hosted on a SageMaker
|
||||
Inference Endpoint via PromptNode. It primarily focuses on LLama-2 models and it supports both the chat and
|
||||
instruction following models. Other Meta models have not been tested.
|
||||
|
||||
For guidance on how to deploy such a model to SageMaker, refer to
|
||||
the [SageMaker JumpStart foundation models documentation](https://docs.aws.amazon.com/sagemaker/latest/dg/jumpstart-foundation-models-use.html)
|
||||
and follow the instructions provided there.
|
||||
|
||||
As of July 24, this layer has been confirmed to support the following SageMaker deployed models:
|
||||
- Llama-2 models
|
||||
|
||||
Technical Note:
|
||||
This layer is designed for models that anticipate an input format composed of the following keys/values:
|
||||
{'inputs': 'prompt_text', 'parameters': params} where 'inputs' represents the prompt and 'parameters' the
|
||||
parameters for the model.
|
||||
|
||||
|
||||
**Examples**
|
||||
|
||||
```python
|
||||
from haystack.nodes import PromptNode
|
||||
|
||||
# Pass sagemaker endpoint name and authentication details
|
||||
pn = PromptNode(model_name_or_path="llama-2-7b",
|
||||
model_kwargs={"aws_profile_name": "my_aws_profile_name"})
|
||||
res = pn("Berlin is the capital of")
|
||||
print(res)
|
||||
```
|
||||
|
||||
**Example using AWS env variables**
|
||||
```python
|
||||
import os
|
||||
from haystack.nodes import PromptNode
|
||||
|
||||
# We can also configure Sagemaker via AWS environment variables without AWS profile name
|
||||
pn = PromptNode(model_name_or_path="llama-2-7b", max_length=512,
|
||||
model_kwargs={"aws_access_key_id": os.getenv("AWS_ACCESS_KEY_ID"),
|
||||
"aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
|
||||
"aws_session_token": os.getenv("AWS_SESSION_TOKEN"),
|
||||
"aws_region_name": "us-east-1"})
|
||||
|
||||
response = pn("The secret for a good life is")
|
||||
print(response)
|
||||
```
|
||||
|
||||
LLama-2 also supports chat format.
|
||||
**Example using chat format**
|
||||
```python
|
||||
from haystack.nodes.prompt import PromptNode
|
||||
pn = PromptNode(model_name_or_path="llama-2-7b-chat", max_length=512, model_kwargs={"aws_profile_name": "default",
|
||||
"aws_custom_attributes": {"accept_eula": True}})
|
||||
pn_input = [[{"role": "user", "content": "what is the recipe of mayonnaise?"}]]
|
||||
response = pn(pn_input)
|
||||
print(response)
|
||||
```
|
||||
|
||||
Note that in the chat examples we can also include multiple turns between the user and the assistant. See the
|
||||
Llama-2 chat documentation for more details.
|
||||
|
||||
**Example using chat format with multiple turns**
|
||||
```python
|
||||
from haystack.nodes.prompt import PromptNode
|
||||
pn = PromptNode(model_name_or_path="llama-2-7b-chat", max_length=512, model_kwargs={"aws_profile_name": "default",
|
||||
"aws_custom_attributes": {"accept_eula": True}})
|
||||
pn_input = [[
|
||||
{"role": "user", "content": "I am going to Paris, what should I see?"},
|
||||
{"role": "assistant", "content": "Paris, the capital of France, is known for its stunning architecture, art museums, historical landmarks, and romantic atmosphere. Here are some of the top attractions to see in Paris:\n
|
||||
1. The Eiffel Tower: The iconic Eiffel Tower is one of the most recognizable landmarks in the world and offers breathtaking views of the city.\n
|
||||
2. The Louvre Museum: The Louvre is one of the world's largest and most famous museums, housing an impressive collection of art and artifacts, including the Mona Lisa.\n
|
||||
3. Notre-Dame Cathedral: This beautiful cathedral is one of the most famous landmarks in Paris and is known for its Gothic architecture and stunning stained glass windows.\n
|
||||
These are just a few of the many attractions that Paris has to offer. With so much to see and do, it's no wonder that Paris is one of the most popular tourist destinations in the world.",},
|
||||
{"role": "user", "content": "What is so great about #1?"}]]
|
||||
response = pn(pn_input)
|
||||
print(response)
|
||||
```
|
||||
|
||||
|
||||
|
||||
Llama-2 models support the following inference payload parameters:
|
||||
|
||||
max_new_tokens: Model generates text until the output length (excluding the input context length) reaches
|
||||
max_new_tokens. If specified, it must be a positive integer.
|
||||
temperature: Controls the randomness in the output. Higher temperature results in output sequence with
|
||||
low-probability words and lower temperature results in output sequence with high-probability words.
|
||||
If temperature -> 0, it results in greedy decoding. If specified, it must be a positive float.
|
||||
top_p: In each step of text generation, sample from the smallest possible set of words with cumulative
|
||||
probability top_p. If specified, it must be a float between 0 and 1.
|
||||
return_full_text: If True, input text will be part of the output generated text. If specified, it must be
|
||||
boolean. The default value for it is False.
|
||||
|
||||
Of course, in both examples your endpoints, region names and other settings will be different.
|
||||
You can find it in the SageMaker AWS console.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name_or_path: str,
|
||||
max_length: int = 100,
|
||||
aws_access_key_id: Optional[str] = None,
|
||||
aws_secret_access_key: Optional[str] = None,
|
||||
aws_session_token: Optional[str] = None,
|
||||
aws_region_name: Optional[str] = None,
|
||||
aws_profile_name: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Instantiates the session with SageMaker using IAM based authentication via boto3.
|
||||
|
||||
:param model_name_or_path: The name for SageMaker Model Endpoint.
|
||||
:param max_length: The maximum length of the output text.
|
||||
:param aws_access_key_id: AWS access key ID.
|
||||
:param aws_secret_access_key: AWS secret access key.
|
||||
:param aws_session_token: AWS session token.
|
||||
:param aws_region_name: AWS region name.
|
||||
:param aws_profile_name: AWS profile name.
|
||||
"""
|
||||
# We set the default max_length to 4096 as this is the context window size supported by the LLama-2 model
|
||||
kwargs.setdefault("model_max_length", 4096)
|
||||
super().__init__(model_name_or_path, max_length=max_length, **kwargs)
|
||||
try:
|
||||
session = SageMakerMetaInvocationLayer.create_session(
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
aws_session_token=aws_session_token,
|
||||
aws_region_name=aws_region_name,
|
||||
aws_profile_name=aws_profile_name,
|
||||
)
|
||||
self.client = session.client("sagemaker-runtime")
|
||||
except Exception as e:
|
||||
raise SageMakerInferenceError(
|
||||
f"Could not connect to SageMaker Inference Endpoint {model_name_or_path}."
|
||||
f"Make sure the Endpoint exists and AWS environment is configured."
|
||||
) from e
|
||||
|
||||
# save the kwargs for the model invocation
|
||||
self.model_input_kwargs = kwargs
|
||||
|
||||
# As of July 24, SageMaker Meta layer does not support streaming responses.
|
||||
# However, even though it's not provided, users may attempt to use streaming responses.
|
||||
# Use stream and stream_handler for warning and future use
|
||||
self.stream_handler = kwargs.get("stream_handler", None)
|
||||
self.stream = kwargs.get("stream", False)
|
||||
|
||||
def invoke(self, *args, **kwargs) -> List[str]:
|
||||
"""
|
||||
Sends the prompt to the remote model and returns the generated response(s).
|
||||
|
||||
:return: The generated responses from the model as a list of strings.
|
||||
"""
|
||||
prompt: Any = kwargs.get("prompt")
|
||||
if not prompt or not isinstance(prompt, (str, list)):
|
||||
raise ValueError(
|
||||
f"No valid prompt provided. Model {self.model_name_or_path} requires a valid prompt."
|
||||
f"Make sure to provide a prompt in the format that the model expects."
|
||||
)
|
||||
|
||||
if not (isinstance(prompt, str) or self.is_proper_chat_conversation_format(prompt)):
|
||||
raise ValueError(
|
||||
f"The prompt format is different than what the model expects. "
|
||||
f"The model {self.model_name_or_path} requires either a string or messages in the specific chat format. "
|
||||
f"For more details, see https://github.com/facebookresearch/llama/blob/main/llama/generation.py#L213)."
|
||||
)
|
||||
|
||||
stream = kwargs.get("stream", self.stream)
|
||||
stream_handler = kwargs.get("stream_handler", self.stream_handler)
|
||||
streaming_requested = stream or stream_handler is not None
|
||||
if streaming_requested:
|
||||
raise SageMakerConfigurationError("SageMaker model response streaming is not supported yet")
|
||||
|
||||
kwargs_with_defaults = self.model_input_kwargs
|
||||
kwargs_with_defaults.update(kwargs)
|
||||
|
||||
default_params = {
|
||||
"max_new_tokens": self.max_length,
|
||||
"return_full_text": None,
|
||||
"temperature": None,
|
||||
"top_p": None,
|
||||
}
|
||||
# put the param in the params if it's in kwargs and not None (e.g. it is actually defined)
|
||||
# endpoint doesn't tolerate None values, send only the params that are defined
|
||||
params = {
|
||||
param: kwargs_with_defaults.get(param, default)
|
||||
for param, default in default_params.items()
|
||||
if param in kwargs_with_defaults or default is not None
|
||||
}
|
||||
generated_texts = self._post(prompt=prompt, params=params)
|
||||
return generated_texts
|
||||
|
||||
def _post(self, prompt: Any, params: Optional[Dict[str, Any]] = None) -> List[str]:
|
||||
"""
|
||||
Post data to the SageMaker inference model. It takes in a prompt and returns a list of responses using model
|
||||
invocation.
|
||||
:param prompt: The prompt text/messages to be sent to the model.
|
||||
:param params: The parameters to be sent to the Meta model.
|
||||
:return: The generated responses as a list of strings.
|
||||
"""
|
||||
custom_attributes = SageMakerBaseInvocationLayer.format_custom_attributes(
|
||||
self.model_input_kwargs.get("aws_custom_attributes", {})
|
||||
)
|
||||
try:
|
||||
body = {"inputs": prompt, "parameters": params}
|
||||
response = self.client.invoke_endpoint(
|
||||
EndpointName=self.model_name_or_path,
|
||||
Body=json.dumps(body),
|
||||
ContentType="application/json",
|
||||
Accept="application/json",
|
||||
CustomAttributes=custom_attributes,
|
||||
)
|
||||
response_json = response.get("Body").read().decode("utf-8")
|
||||
output = json.loads(response_json)
|
||||
generated_texts = [o["generation"] for o in output if "generation" in o]
|
||||
return generated_texts
|
||||
except requests.HTTPError as err:
|
||||
res = err.response
|
||||
if res.status_code == 429:
|
||||
raise SageMakerModelNotReadyError(f"Model not ready: {res.text}") from err
|
||||
raise SageMakerInferenceError(
|
||||
f"SageMaker Inference returned an error.\nStatus code: {res.status_code}\nResponse body: {res.text}",
|
||||
status_code=res.status_code,
|
||||
) from err
|
||||
|
||||
def _ensure_token_limit(self, prompt: Union[str, List[Dict[str, str]]]) -> Union[str, List[Dict[str, str]]]:
|
||||
# the prompt for this model will be of the type str
|
||||
if isinstance(prompt, str):
|
||||
return super()._ensure_token_limit(prompt)
|
||||
else:
|
||||
# TODO: implement truncation for the chat format
|
||||
return prompt
|
||||
|
||||
def _is_proper_chat_message_format(self, chat_message: Dict[str, str]) -> bool:
|
||||
"""
|
||||
Checks whether a chat message is in the proper format.
|
||||
:param chat_message: The chat message to be checked.
|
||||
:return: True if the chat message is in the proper format, False otherwise.
|
||||
"""
|
||||
allowed_roles = {"user", "assistant", "system"}
|
||||
return (
|
||||
isinstance(chat_message, dict)
|
||||
and "role" in chat_message
|
||||
and "content" in chat_message
|
||||
and chat_message["role"] in allowed_roles
|
||||
)
|
||||
|
||||
def is_proper_chat_conversation_format(self, prompt: List[Any]) -> bool:
|
||||
"""
|
||||
Checks whether a chat conversation is in the proper format.
|
||||
:param prompt: The chat conversation to be checked.
|
||||
:return: True if the chat conversation is in the proper format, False otherwise.
|
||||
"""
|
||||
if not isinstance(prompt, list) or len(prompt) == 0:
|
||||
return False
|
||||
|
||||
return all(
|
||||
isinstance(message_list, list)
|
||||
and all(self._is_proper_chat_message_format(chat_message) for chat_message in message_list)
|
||||
for message_list in prompt
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_test_payload(cls) -> Dict[str, Any]:
|
||||
"""
|
||||
Return test payload for the model.
|
||||
"""
|
||||
# implement the abstract method to fulfill the contract, but it won't be used
|
||||
# because we override the supports method to check support
|
||||
# for the chat and instruction following format manually
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def supports(cls, model_name_or_path: str, **kwargs) -> bool:
|
||||
"""
|
||||
Checks whether a model_name_or_path passed down (e.g. via PromptNode) is supported by this class.
|
||||
|
||||
:param model_name_or_path: The model_name_or_path to check.
|
||||
"""
|
||||
aws_configuration_keys = [
|
||||
"aws_access_key_id",
|
||||
"aws_secret_access_key",
|
||||
"aws_session_token",
|
||||
"aws_region_name",
|
||||
"aws_profile_name",
|
||||
]
|
||||
aws_config_provided = any(key in kwargs for key in aws_configuration_keys)
|
||||
accept_eula = False
|
||||
if "aws_custom_attributes" in kwargs and isinstance(kwargs["aws_custom_attributes"], dict):
|
||||
accept_eula = kwargs["aws_custom_attributes"].get("accept_eula", False)
|
||||
if aws_config_provided and accept_eula:
|
||||
boto3_import.check()
|
||||
# attempt to create a session with the provided credentials
|
||||
session = cls.check_aws_connect(aws_configuration_keys, kwargs)
|
||||
# is endpoint in service?
|
||||
cls.check_endpoint_in_service(session, model_name_or_path)
|
||||
|
||||
# let's test both formats as we want to support both chat and instruction following models:
|
||||
# 1. instruction following format
|
||||
# send test payload to endpoint to see if it's supported
|
||||
instruction_test_payload: Dict[str, Any] = {
|
||||
"inputs": "Hello world",
|
||||
# don't remove max_new_tokens param, if we don't specify it, the model will generate 4k tokens
|
||||
"parameters": {"max_new_tokens": 10},
|
||||
}
|
||||
supported = cls.check_model_input_format(session, model_name_or_path, instruction_test_payload, **kwargs)
|
||||
if supported:
|
||||
return True
|
||||
|
||||
# 2. chat format
|
||||
chat_test_payload: Dict[str, Any] = {
|
||||
"inputs": [[{"role": "user", "content": "what is the recipe of mayonnaise?"}]],
|
||||
# don't remove max_new_tokens param, if we don't specify it, the model will generate 4k tokens
|
||||
"parameters": {"max_new_tokens": 10},
|
||||
}
|
||||
supported = cls.check_model_input_format(session, model_name_or_path, chat_test_payload, **kwargs)
|
||||
return supported
|
||||
return False
|
||||
@ -7,8 +7,13 @@ from haystack.nodes.prompt.invocation_layer import HFLocalInvocationLayer, HFInf
|
||||
@pytest.mark.unit
|
||||
def test_invocation_layer_order():
|
||||
"""
|
||||
Checks that the huggingface invocation layer is checked late because it can timeout/be slow to respond.
|
||||
Checks that the huggingface invocation layer is positioned further down the list of providers
|
||||
as they can time out or be slow to respond.
|
||||
"""
|
||||
last_invocation_layers = set(PromptModelInvocationLayer.invocation_layer_providers[-5:])
|
||||
assert HFLocalInvocationLayer in last_invocation_layers
|
||||
assert HFInferenceEndpointInvocationLayer in last_invocation_layers
|
||||
invocation_layers = PromptModelInvocationLayer.invocation_layer_providers
|
||||
assert HFLocalInvocationLayer in invocation_layers
|
||||
assert HFInferenceEndpointInvocationLayer in invocation_layers
|
||||
index_hf = invocation_layers.index(HFLocalInvocationLayer) + 1
|
||||
index_hf_inference = invocation_layers.index(HFInferenceEndpointInvocationLayer) + 1
|
||||
assert index_hf > len(invocation_layers) / 2
|
||||
assert index_hf_inference > len(invocation_layers) / 2
|
||||
|
||||
473
test/prompt/invocation_layer/test_sagemaker_meta.py
Normal file
473
test/prompt/invocation_layer/test_sagemaker_meta.py
Normal file
@ -0,0 +1,473 @@
|
||||
import os
|
||||
from unittest.mock import patch, MagicMock, Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from haystack.lazy_imports import LazyImport
|
||||
|
||||
from haystack.errors import SageMakerConfigurationError
|
||||
from haystack.nodes.prompt.invocation_layer import SageMakerMetaInvocationLayer
|
||||
|
||||
with LazyImport() as boto3_import:
|
||||
from botocore.exceptions import BotoCoreError
|
||||
|
||||
|
||||
# create a fixture with mocked boto3 client and session
|
||||
@pytest.fixture
|
||||
def mock_boto3_session():
|
||||
with patch("boto3.Session") as mock_client:
|
||||
yield mock_client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_prompt_handler():
|
||||
with patch("haystack.nodes.prompt.invocation_layer.handlers.DefaultPromptHandler") as mock_prompt_handler:
|
||||
yield mock_prompt_handler
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_default_constructor(mock_auto_tokenizer, mock_boto3_session):
|
||||
"""
|
||||
Test that the default constructor sets the correct values
|
||||
"""
|
||||
|
||||
layer = SageMakerMetaInvocationLayer(
|
||||
model_name_or_path="some_fake_model",
|
||||
max_length=99,
|
||||
aws_access_key_id="some_fake_id",
|
||||
aws_secret_access_key="some_fake_key",
|
||||
aws_session_token="some_fake_token",
|
||||
aws_profile_name="some_fake_profile",
|
||||
aws_region_name="fake_region",
|
||||
)
|
||||
|
||||
assert layer.max_length == 99
|
||||
assert layer.model_name_or_path == "some_fake_model"
|
||||
|
||||
assert layer.prompt_handler is not None
|
||||
assert layer.prompt_handler.model_max_length == 4096
|
||||
|
||||
# assert mocked boto3 client called exactly once
|
||||
mock_boto3_session.assert_called_once()
|
||||
|
||||
# assert mocked boto3 client was called with the correct parameters
|
||||
mock_boto3_session.assert_called_with(
|
||||
aws_access_key_id="some_fake_id",
|
||||
aws_secret_access_key="some_fake_key",
|
||||
aws_session_token="some_fake_token",
|
||||
profile_name="some_fake_profile",
|
||||
region_name="fake_region",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_constructor_prompt_handler_initialized(mock_auto_tokenizer, mock_boto3_session):
|
||||
"""
|
||||
Test that the constructor sets the prompt_handler correctly, with the correct model_max_length for llama-2
|
||||
"""
|
||||
layer = SageMakerMetaInvocationLayer(model_name_or_path="some_fake_model", prompt_handler=mock_prompt_handler)
|
||||
assert layer.prompt_handler is not None
|
||||
assert layer.prompt_handler.model_max_length == 4096
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_constructor_with_model_kwargs(mock_auto_tokenizer, mock_boto3_session):
|
||||
"""
|
||||
Test that model_kwargs are correctly set in the constructor
|
||||
"""
|
||||
model_kwargs = {"temperature": 0.7}
|
||||
|
||||
layer = SageMakerMetaInvocationLayer(model_name_or_path="some_fake_model", **model_kwargs)
|
||||
assert "temperature" in layer.model_input_kwargs
|
||||
assert layer.model_input_kwargs["temperature"] == 0.7
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_constructor_with_empty_model_name():
|
||||
"""
|
||||
Test that the constructor raises an error when the model_name_or_path is empty
|
||||
"""
|
||||
with pytest.raises(ValueError, match="cannot be None or empty string"):
|
||||
SageMakerMetaInvocationLayer(model_name_or_path="")
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_invoke_with_no_kwargs(mock_auto_tokenizer, mock_boto3_session):
|
||||
"""
|
||||
Test invoke raises an error if no prompt is provided
|
||||
"""
|
||||
layer = SageMakerMetaInvocationLayer(model_name_or_path="some_fake_model")
|
||||
with pytest.raises(ValueError) as e:
|
||||
layer.invoke()
|
||||
assert e.match("No prompt provided.")
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_invoke_with_stop_words(mock_auto_tokenizer, mock_boto3_session):
|
||||
"""
|
||||
SageMakerMetaInvocationLayer does not support stop words. Tests that they'll be ignored
|
||||
"""
|
||||
stop_words = ["but", "not", "bye"]
|
||||
layer = SageMakerMetaInvocationLayer(model_name_or_path="some_model")
|
||||
with patch("haystack.nodes.prompt.invocation_layer.SageMakerMetaInvocationLayer._post") as mock_post:
|
||||
# Mock the response, need to return a list of dicts
|
||||
mock_post.return_value = MagicMock(text='[{"generated_text": "Hello"}]')
|
||||
|
||||
layer.invoke(prompt="Tell me hello", stop_words=stop_words)
|
||||
|
||||
assert mock_post.called
|
||||
_, call_kwargs = mock_post.call_args
|
||||
assert "stop_words" not in call_kwargs["params"]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_short_prompt_is_not_truncated(mock_boto3_session):
|
||||
"""
|
||||
Test that a short prompt is not truncated
|
||||
"""
|
||||
# Define a short mock prompt and its tokenized version
|
||||
mock_prompt_text = "I am a tokenized prompt"
|
||||
mock_prompt_tokens = mock_prompt_text.split()
|
||||
|
||||
# Mock the tokenizer so it returns our predefined tokens
|
||||
mock_tokenizer = MagicMock()
|
||||
mock_tokenizer.tokenize.return_value = mock_prompt_tokens
|
||||
|
||||
# We set a small max_length for generated text (3 tokens) and a total model_max_length of 10 tokens
|
||||
# Since our mock prompt is 5 tokens long, it doesn't exceed the
|
||||
# total limit (5 prompt tokens + 3 generated tokens < 10 tokens)
|
||||
max_length_generated_text = 3
|
||||
total_model_max_length = 10
|
||||
|
||||
with patch("transformers.AutoTokenizer.from_pretrained", return_value=mock_tokenizer):
|
||||
layer = SageMakerMetaInvocationLayer(
|
||||
"some_fake_endpoint", max_length=max_length_generated_text, model_max_length=total_model_max_length
|
||||
)
|
||||
prompt_after_resize = layer._ensure_token_limit(mock_prompt_text)
|
||||
|
||||
# The prompt doesn't exceed the limit, _ensure_token_limit doesn't truncate it
|
||||
assert prompt_after_resize == mock_prompt_text
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_long_prompt_is_truncated(mock_boto3_session):
|
||||
"""
|
||||
Test that a long prompt is truncated
|
||||
"""
|
||||
# Define a long mock prompt and its tokenized version
|
||||
long_prompt_text = "I am a tokenized prompt of length eight"
|
||||
long_prompt_tokens = long_prompt_text.split()
|
||||
|
||||
# _ensure_token_limit will truncate the prompt to make it fit into the model's max token limit
|
||||
truncated_prompt_text = "I am a tokenized prompt of length"
|
||||
|
||||
# Mock the tokenizer to return our predefined tokens
|
||||
# convert tokens to our predefined truncated text
|
||||
mock_tokenizer = MagicMock()
|
||||
mock_tokenizer.tokenize.return_value = long_prompt_tokens
|
||||
mock_tokenizer.convert_tokens_to_string.return_value = truncated_prompt_text
|
||||
|
||||
# We set a small max_length for generated text (3 tokens) and a total model_max_length of 10 tokens
|
||||
# Our mock prompt is 8 tokens long, so it exceeds the total limit (8 prompt tokens + 3 generated tokens > 10 tokens)
|
||||
max_length_generated_text = 3
|
||||
total_model_max_length = 10
|
||||
|
||||
with patch("transformers.AutoTokenizer.from_pretrained", return_value=mock_tokenizer):
|
||||
layer = SageMakerMetaInvocationLayer(
|
||||
"some_fake_endpoint", max_length=max_length_generated_text, model_max_length=total_model_max_length
|
||||
)
|
||||
prompt_after_resize = layer._ensure_token_limit(long_prompt_text)
|
||||
|
||||
# The prompt exceeds the limit, _ensure_token_limit truncates it
|
||||
assert prompt_after_resize == truncated_prompt_text
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_streaming_init_kwarg(mock_auto_tokenizer, mock_boto3_session):
|
||||
"""
|
||||
Test stream parameter passed as init kwarg raises an error on layer invocation
|
||||
"""
|
||||
layer = SageMakerMetaInvocationLayer(model_name_or_path="irrelevant", stream=True)
|
||||
|
||||
with pytest.raises(SageMakerConfigurationError, match="SageMaker model response streaming is not supported yet"):
|
||||
layer.invoke(prompt="Tell me hello")
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_streaming_invoke_kwarg(mock_auto_tokenizer, mock_boto3_session):
|
||||
"""
|
||||
Test stream parameter passed as invoke kwarg raises an error on layer invocation
|
||||
"""
|
||||
layer = SageMakerMetaInvocationLayer(model_name_or_path="irrelevant")
|
||||
|
||||
with pytest.raises(SageMakerConfigurationError, match="SageMaker model response streaming is not supported yet"):
|
||||
layer.invoke(prompt="Tell me hello", stream=True)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_streaming_handler_init_kwarg(mock_auto_tokenizer, mock_boto3_session):
|
||||
"""
|
||||
Test stream_handler parameter passed as init kwarg raises an error on layer invocation
|
||||
"""
|
||||
layer = SageMakerMetaInvocationLayer(model_name_or_path="irrelevant", stream_handler=Mock())
|
||||
|
||||
with pytest.raises(SageMakerConfigurationError, match="SageMaker model response streaming is not supported yet"):
|
||||
layer.invoke(prompt="Tell me hello")
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_streaming_handler_invoke_kwarg(mock_auto_tokenizer, mock_boto3_session):
|
||||
"""
|
||||
Test stream_handler parameter passed as invoke kwarg raises an error on layer invocation
|
||||
"""
|
||||
layer = SageMakerMetaInvocationLayer(model_name_or_path="irrelevant")
|
||||
|
||||
with pytest.raises(SageMakerConfigurationError, match="SageMaker model response streaming is not supported yet"):
|
||||
layer.invoke(prompt="Tell me hello", stream_handler=Mock())
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_supports_for_valid_aws_configuration():
|
||||
"""
|
||||
Test that the SageMakerMetaInvocationLayer identifies a valid SageMaker Inference endpoint via the supports() method
|
||||
"""
|
||||
mock_client = MagicMock()
|
||||
mock_client.describe_endpoint.return_value = {"EndpointStatus": "InService"}
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.client.return_value = mock_client
|
||||
|
||||
# Patch the class method to return the mock session
|
||||
with patch(
|
||||
"haystack.nodes.prompt.invocation_layer.sagemaker_base.SageMakerBaseInvocationLayer.create_session",
|
||||
return_value=mock_session,
|
||||
):
|
||||
supported = SageMakerMetaInvocationLayer.supports(
|
||||
model_name_or_path="some_sagemaker_deployed_model",
|
||||
aws_profile_name="some_real_profile",
|
||||
aws_custom_attributes={"accept_eula": True},
|
||||
)
|
||||
args, kwargs = mock_client.describe_endpoint.call_args
|
||||
assert kwargs["EndpointName"] == "some_sagemaker_deployed_model"
|
||||
|
||||
args, kwargs = mock_session.client.call_args
|
||||
assert args[0] == "sagemaker-runtime"
|
||||
assert supported
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_supports_not_on_invalid_aws_profile_name():
|
||||
"""
|
||||
Test that the SageMakerMetaInvocationLayer raises SageMakerConfigurationError when the profile name is invalid
|
||||
"""
|
||||
|
||||
with patch("boto3.Session") as mock_boto3_session:
|
||||
mock_boto3_session.side_effect = BotoCoreError()
|
||||
with pytest.raises(SageMakerConfigurationError) as exc_info:
|
||||
supported = SageMakerMetaInvocationLayer.supports(
|
||||
model_name_or_path="some_fake_model",
|
||||
aws_profile_name="some_fake_profile",
|
||||
aws_custom_attributes={"accept_eula": True},
|
||||
)
|
||||
assert "Failed to initialize the session" in exc_info.value
|
||||
assert not supported
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_supports_not_on_missing_eula():
|
||||
"""
|
||||
Test that the SageMakerMetaInvocationLayer is not supported when the EULA is missing
|
||||
"""
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.describe_endpoint.return_value = {"EndpointStatus": "InService"}
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.client.return_value = mock_client
|
||||
|
||||
# Patch the class method to return the mock session
|
||||
with patch(
|
||||
"haystack.nodes.prompt.invocation_layer.sagemaker_base.SageMakerBaseInvocationLayer.create_session",
|
||||
return_value=mock_session,
|
||||
):
|
||||
supported = SageMakerMetaInvocationLayer.supports(
|
||||
model_name_or_path="some_sagemaker_deployed_model", aws_profile_name="some_real_profile"
|
||||
)
|
||||
|
||||
assert not supported
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_supports_not_on_eula_not_accepted():
|
||||
"""
|
||||
Test that the SageMakerMetaInvocationLayer is not supported when the EULA is not accepted
|
||||
"""
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.describe_endpoint.return_value = {"EndpointStatus": "InService"}
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.client.return_value = mock_client
|
||||
|
||||
# Patch the class method to return the mock session
|
||||
with patch(
|
||||
"haystack.nodes.prompt.invocation_layer.sagemaker_base.SageMakerBaseInvocationLayer.create_session",
|
||||
return_value=mock_session,
|
||||
):
|
||||
supported = SageMakerMetaInvocationLayer.supports(
|
||||
model_name_or_path="some_sagemaker_deployed_model",
|
||||
aws_profile_name="some_real_profile",
|
||||
aws_custom_attributes={"accept_eula": False},
|
||||
)
|
||||
assert not supported
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_format_custom_attributes_with_non_empty_dict():
|
||||
"""
|
||||
Test that the SageMakerMetaInvocationLayer correctly formats the custom attributes, attributes specified
|
||||
"""
|
||||
attributes = {"key1": "value1", "key2": "value2"}
|
||||
expected_output = "key1=value1;key2=value2"
|
||||
assert SageMakerMetaInvocationLayer.format_custom_attributes(attributes) == expected_output
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_format_custom_attributes_with_empty_dict():
|
||||
"""
|
||||
Test that the SageMakerMetaInvocationLayer correctly formats the custom attributes, attributes not specified
|
||||
"""
|
||||
attributes = {}
|
||||
expected_output = ""
|
||||
assert SageMakerMetaInvocationLayer.format_custom_attributes(attributes) == expected_output
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_format_custom_attributes_with_none():
|
||||
"""
|
||||
Test that the SageMakerMetaInvocationLayer correctly formats the custom attributes, attributes are None
|
||||
"""
|
||||
attributes = None
|
||||
expected_output = ""
|
||||
assert SageMakerMetaInvocationLayer.format_custom_attributes(attributes) == expected_output
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_format_custom_attributes_with_bool_value():
|
||||
"""
|
||||
Test that the SageMakerMetaInvocationLayer correctly formats the custom attributes, attributes are bool
|
||||
"""
|
||||
attributes = {"key1": True, "key2": False}
|
||||
expected_output = "key1=true;key2=false"
|
||||
assert SageMakerMetaInvocationLayer.format_custom_attributes(attributes) == expected_output
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_format_custom_attributes_with_single_bool_value():
|
||||
"""
|
||||
Test that the SageMakerMetaInvocationLayer correctly formats the custom attributes, attributes are single bool
|
||||
"""
|
||||
attributes = {"key1": True}
|
||||
expected_output = "key1=true"
|
||||
assert SageMakerMetaInvocationLayer.format_custom_attributes(attributes) == expected_output
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_format_custom_attributes_with_int_value():
|
||||
"""
|
||||
Test that the SageMakerMetaInvocationLayer correctly formats the custom attributes, attributes are ints
|
||||
"""
|
||||
attributes = {"key1": 1, "key2": 2}
|
||||
expected_output = "key1=1;key2=2"
|
||||
assert SageMakerMetaInvocationLayer.format_custom_attributes(attributes) == expected_output
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_invoke_chat_format(mock_auto_tokenizer, mock_boto3_session):
|
||||
"""
|
||||
Test that the SageMakerMetaInvocationLayer accepts a chat in the correct format
|
||||
"""
|
||||
# test the format of the chat, no exception should be raised
|
||||
layer = SageMakerMetaInvocationLayer(model_name_or_path="some_fake_model")
|
||||
prompt = [[{"role": "user", "content": "Hello"}]]
|
||||
expected_response = [[{"role": "user", "content": "Hello"}, {"role": "assistant", "content": "Hello there"}]]
|
||||
|
||||
with patch("haystack.nodes.prompt.invocation_layer.sagemaker_meta.SageMakerMetaInvocationLayer._post") as mock_post:
|
||||
mock_post.return_value = expected_response
|
||||
layer.invoke(prompt=prompt)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_invoke_invalid_chat_format(mock_auto_tokenizer, mock_boto3_session):
|
||||
"""
|
||||
Test that the SageMakerMetaInvocationLayer raises an exception when the chat is in the wrong format
|
||||
"""
|
||||
# test the invalid format of the chat, should raise an exception
|
||||
layer = SageMakerMetaInvocationLayer(model_name_or_path="some_fake_model")
|
||||
prompt = [{"roe": "user", "cotent": "Hello"}]
|
||||
expected_response = [[{"role": "user", "content": "Hello"}, {"role": "assistant", "content": "Hello there"}]]
|
||||
|
||||
with patch("haystack.nodes.prompt.invocation_layer.sagemaker_meta.SageMakerMetaInvocationLayer._post") as mock_post:
|
||||
mock_post.return_value = expected_response
|
||||
with pytest.raises(ValueError, match="The prompt format is different than what the model expects"):
|
||||
layer.invoke(prompt=prompt)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_invoke_prompt_string(mock_auto_tokenizer, mock_boto3_session):
|
||||
"""
|
||||
Test that the SageMakerMetaInvocationLayer accepts a prompt in the correct string format
|
||||
"""
|
||||
# test the format of the prompt instruction, no exception should be raised
|
||||
layer = SageMakerMetaInvocationLayer(model_name_or_path="some_fake_model")
|
||||
with patch("haystack.nodes.prompt.invocation_layer.sagemaker_meta.SageMakerMetaInvocationLayer._post") as mock_post:
|
||||
mock_post.return_value = ["Hello there"]
|
||||
layer.invoke(prompt="Hello")
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_invoke_empty_prompt(mock_auto_tokenizer, mock_boto3_session):
|
||||
"""
|
||||
Test that the SageMakerMetaInvocationLayer raises an exception when the prompt is empty string
|
||||
"""
|
||||
layer = SageMakerMetaInvocationLayer(model_name_or_path="some_fake_model")
|
||||
with pytest.raises(ValueError):
|
||||
layer.invoke(prompt="")
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_invoke_improper_prompt_type(mock_auto_tokenizer, mock_boto3_session):
|
||||
"""
|
||||
Test that the SageMakerMetaInvocationLayer raises an exception when the prompt is int instead of str
|
||||
"""
|
||||
layer = SageMakerMetaInvocationLayer(model_name_or_path="some_fake_model")
|
||||
prompt = 123
|
||||
with pytest.raises(ValueError):
|
||||
layer.invoke(prompt=prompt)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not os.environ.get("TEST_SAGEMAKER_MODEL_ENDPOINT", None), reason="Skipping because SageMaker not configured"
|
||||
)
|
||||
@pytest.mark.integration
|
||||
def test_supports_triggered_for_valid_sagemaker_endpoint():
|
||||
"""
|
||||
Test that the SageMakerMetaInvocationLayer identifies a valid SageMaker Inference endpoint via the supports() method
|
||||
"""
|
||||
model_name_or_path = os.environ.get("TEST_SAGEMAKER_MODEL_ENDPOINT")
|
||||
assert SageMakerMetaInvocationLayer.supports(model_name_or_path=model_name_or_path)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not os.environ.get("TEST_SAGEMAKER_MODEL_ENDPOINT", None), reason="Skipping because SageMaker not configured"
|
||||
)
|
||||
@pytest.mark.integration
|
||||
def test_supports_not_triggered_for_invalid_iam_profile():
|
||||
"""
|
||||
Test that the SageMakerMetaInvocationLayer identifies an invalid SageMaker Inference endpoint
|
||||
(in this case because of an invalid IAM AWS Profile via the supports() method)
|
||||
"""
|
||||
assert not SageMakerMetaInvocationLayer.supports(model_name_or_path="fake_endpoint")
|
||||
assert not SageMakerMetaInvocationLayer.supports(
|
||||
model_name_or_path="fake_endpoint", aws_profile_name="invalid-profile"
|
||||
)
|
||||
Loading…
x
Reference in New Issue
Block a user