feat: add Amazon Bedrock support (#6226)

* Add Bedrock

* Update supported models for Bedrock

* Fix supports and add extract response in Bedrock

* fix import errors

* improve and refactor supports

* fix install

* fix mypy

* fix pylint

* fix existing tests

* Added Anthropic Bedrock

* fix tests

* fix sagemaker tests

* add default prompt handler, constructor and supports tests

* more tests

* invoke refactoring

* refactor model_kwargs

* fix mypy

* lstrip responses

* Add streaming support

* bump boto3 version

* add class docstrings, better exception names

* fix layer name

* add tests for anthropic and cohere model adapters

* update cohere params

* update ai21 args and add tests

* support cohere command light model

* add Titan tests

* better class names

* support meta llama 2 model

* fix streaming support

* more future-proof model adapter selection

* fix import

* fix mypy

* fix pylint for preview

* add tests for streaming

* add release notes

* Apply suggestions from code review

Co-authored-by: Agnieszka Marzec <97166305+agnieszka-m@users.noreply.github.com>

* fix format

* fix tests after msg changes

* fix streaming for cohere

---------

Co-authored-by: tstadel <60758086+tstadel@users.noreply.github.com>
Co-authored-by: tstadel <thomas.stadelmann@deepset.ai>
Co-authored-by: Agnieszka Marzec <97166305+agnieszka-m@users.noreply.github.com>
Author: Vivek Silimkhan
Date: 2023-11-15 17:56:29 +05:30 (committed by GitHub)
Parent: 08ec492039
Commit: f998bf4a4f
16 changed files with 1564 additions and 116 deletions


@@ -63,10 +63,7 @@ jobs:
        uses: tj-actions/changed-files@v40
        with:
          files: |
            **/*.py
          files_ignore: |
            test/**
            rest_api/test/**
            haystack/preview/**/*.py
      - uses: actions/setup-python@v4
        with:


@@ -202,6 +202,27 @@ class HuggingFaceInferenceUnauthorizedError(HuggingFaceInferenceError):
    """Exception for issues that occur in the HuggingFace inference node due to unauthorized access"""


class AWSConfigurationError(NodeError):
    """Exception raised when AWS is not configured correctly"""

    def __init__(self, message: Optional[str] = None, send_message_in_event: bool = False):
        super().__init__(message=message, send_message_in_event=send_message_in_event)


class AmazonBedrockConfigurationError(NodeError):
    """Exception raised when the Amazon Bedrock node is not configured correctly"""

    def __init__(self, message: Optional[str] = None, send_message_in_event: bool = False):
        super().__init__(message=message, send_message_in_event=send_message_in_event)


class AmazonBedrockInferenceError(NodeError):
    """Exception for issues that occur in the Amazon Bedrock inference node"""

    def __init__(self, message: Optional[str] = None, send_message_in_event: bool = False):
        super().__init__(message=message, send_message_in_event=send_message_in_event)


class SageMakerInferenceError(NodeError):
    """Exception for issues that occur in the SageMaker inference node"""


@@ -8,6 +8,7 @@ from haystack.nodes.prompt.invocation_layer.anthropic_claude import AnthropicCla
from haystack.nodes.prompt.invocation_layer.cohere import CohereInvocationLayer
from haystack.nodes.prompt.invocation_layer.hugging_face import HFLocalInvocationLayer
from haystack.nodes.prompt.invocation_layer.hugging_face_inference import HFInferenceEndpointInvocationLayer
from haystack.nodes.prompt.invocation_layer.amazon_bedrock import AmazonBedrockInvocationLayer
from haystack.nodes.prompt.invocation_layer.sagemaker_meta import SageMakerMetaInvocationLayer
from haystack.nodes.prompt.invocation_layer.sagemaker_hf_infer import SageMakerHFInferenceInvocationLayer
from haystack.nodes.prompt.invocation_layer.sagemaker_hf_text_gen import SageMakerHFTextGenerationInvocationLayer


@@ -0,0 +1,372 @@
from abc import ABC, abstractmethod
import json
import logging
import re
from typing import Any, Optional, Dict, Type, Union, List

from haystack.errors import AWSConfigurationError, AmazonBedrockConfigurationError, AmazonBedrockInferenceError
from haystack.lazy_imports import LazyImport
from haystack.nodes.prompt.invocation_layer.aws_base import AWSBaseInvocationLayer
from haystack.nodes.prompt.invocation_layer.handlers import (
    DefaultPromptHandler,
    DefaultTokenStreamingHandler,
    TokenStreamingHandler,
)

logger = logging.getLogger(__name__)

with LazyImport(message="Run 'pip install farm-haystack[aws]'") as boto3_import:
    from botocore.exceptions import ClientError


class BedrockModelAdapter(ABC):
    """
    Base class for Amazon Bedrock model adapters.
    """

    def __init__(self, model_kwargs: Dict[str, Any], max_length: Optional[int]) -> None:
        self.model_kwargs = model_kwargs
        self.max_length = max_length

    @abstractmethod
    def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]:
        """Prepares the body for the Amazon Bedrock request."""

    def get_responses(self, response_body: Dict[str, Any]) -> List[str]:
        """Extracts the responses from the Amazon Bedrock response."""
        completions = self._extract_completions_from_response(response_body)
        responses = [completion.lstrip() for completion in completions]
        return responses

    def get_stream_responses(self, stream, stream_handler: TokenStreamingHandler) -> List[str]:
        tokens: List[str] = []
        for event in stream:
            chunk = event.get("chunk")
            if chunk:
                decoded_chunk = json.loads(chunk["bytes"].decode("utf-8"))
                token = self._extract_token_from_stream(decoded_chunk)
                tokens.append(stream_handler(token, event_data=decoded_chunk))
        responses = ["".join(tokens).lstrip()]
        return responses

    def _get_params(self, inference_kwargs: Dict[str, Any], default_params: Dict[str, Any]) -> Dict[str, Any]:
        """
        Merges the default params with the inference kwargs and model kwargs.

        Includes a param if it's in kwargs or if its default is not None (i.e. it is actually defined).
        """
        kwargs = self.model_kwargs.copy()
        kwargs.update(inference_kwargs)
        return {
            param: kwargs.get(param, default)
            for param, default in default_params.items()
            if param in kwargs or default is not None
        }

    @abstractmethod
    def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]:
        """Extracts the completions from the Amazon Bedrock response."""

    @abstractmethod
    def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
        """Extracts the token from a streaming chunk."""
class AnthropicClaudeAdapter(BedrockModelAdapter):
    """
    Model adapter for Anthropic's Claude model.
    """

    def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]:
        default_params = {
            "max_tokens_to_sample": self.max_length,
            "stop_sequences": ["\n\nHuman:"],
            "temperature": None,
            "top_p": None,
            "top_k": None,
        }
        params = self._get_params(inference_kwargs, default_params)
        body = {"prompt": f"\n\nHuman: {prompt}\n\nAssistant:", **params}
        return body

    def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]:
        return [response_body["completion"]]

    def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
        return chunk.get("completion", "")


class CohereCommandAdapter(BedrockModelAdapter):
    """
    Model adapter for Cohere's Command model.
    """

    def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]:
        default_params = {
            "max_tokens": self.max_length,
            "stop_sequences": None,
            "temperature": None,
            "p": None,
            "k": None,
            "return_likelihoods": None,
            "stream": None,
            "logit_bias": None,
            "num_generations": None,
            "truncate": None,
        }
        params = self._get_params(inference_kwargs, default_params)
        body = {"prompt": prompt, **params}
        return body

    def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]:
        responses = [generation["text"] for generation in response_body["generations"]]
        return responses

    def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
        return chunk.get("text", "")


class AI21LabsJurassic2Adapter(BedrockModelAdapter):
    """
    Model adapter for AI21 Labs' Jurassic 2 models.
    """

    def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]:
        default_params = {
            "maxTokens": self.max_length,
            "stopSequences": None,
            "temperature": None,
            "topP": None,
            "countPenalty": None,
            "presencePenalty": None,
            "frequencyPenalty": None,
            "numResults": None,
        }
        params = self._get_params(inference_kwargs, default_params)
        body = {"prompt": prompt, **params}
        return body

    def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]:
        responses = [completion["data"]["text"] for completion in response_body["completions"]]
        return responses

    def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
        raise NotImplementedError("Streaming is not supported for AI21 Jurassic 2 models.")


class AmazonTitanAdapter(BedrockModelAdapter):
    """
    Model adapter for Amazon's Titan models.
    """

    def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]:
        default_params = {"maxTokenCount": self.max_length, "stopSequences": None, "temperature": None, "topP": None}
        params = self._get_params(inference_kwargs, default_params)
        body = {"inputText": prompt, "textGenerationConfig": params}
        return body

    def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]:
        responses = [result["outputText"] for result in response_body["results"]]
        return responses

    def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
        return chunk.get("outputText", "")


class MetaLlama2ChatAdapter(BedrockModelAdapter):
    """
    Model adapter for Meta's Llama 2 Chat models.
    """

    def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]:
        default_params = {"max_gen_len": self.max_length, "temperature": None, "top_p": None}
        params = self._get_params(inference_kwargs, default_params)
        body = {"prompt": prompt, **params}
        return body

    def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]:
        return [response_body["generation"]]

    def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
        return chunk.get("generation", "")
class AmazonBedrockInvocationLayer(AWSBaseInvocationLayer):
    """
    Invocation layer for Amazon Bedrock models.
    """

    SUPPORTED_MODEL_PATTERNS: Dict[str, Type[BedrockModelAdapter]] = {
        r"amazon.titan-text.*": AmazonTitanAdapter,
        r"ai21.j2.*": AI21LabsJurassic2Adapter,
        r"cohere.command.*": CohereCommandAdapter,
        r"anthropic.claude.*": AnthropicClaudeAdapter,
        r"meta.llama2.*": MetaLlama2ChatAdapter,
    }

    def __init__(
        self,
        model_name_or_path: str,
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
        aws_session_token: Optional[str] = None,
        aws_region_name: Optional[str] = None,
        aws_profile_name: Optional[str] = None,
        max_length: Optional[int] = 100,
        **kwargs,
    ):
        super().__init__(model_name_or_path, **kwargs)
        self.max_length = max_length

        try:
            session = self.get_aws_session(
                aws_access_key_id=aws_access_key_id,
                aws_secret_access_key=aws_secret_access_key,
                aws_session_token=aws_session_token,
                aws_region_name=aws_region_name,
                aws_profile_name=aws_profile_name,
            )
            self.client = session.client("bedrock-runtime")
        except Exception as exception:
            raise AmazonBedrockConfigurationError(
                "Could not connect to Amazon Bedrock. Make sure the AWS environment is configured correctly. "
                "See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/quickstart.html#configuration"
            ) from exception

        model_input_kwargs = kwargs
        # We pop the model_max_length as it is not sent to the model
        # but used to truncate the prompt if needed
        model_max_length = kwargs.get("model_max_length", 4096)

        # Truncate the prompt if prompt tokens > model_max_length - max_length
        # (max_length is the length of the generated text).
        # It is hard to determine which tokenizer to use for a Bedrock model,
        # so we use the GPT-2 tokenizer, which will likely provide a good token count approximation.
        self.prompt_handler = DefaultPromptHandler(
            model_name_or_path="gpt2", model_max_length=model_max_length, max_length=self.max_length or 100
        )

        model_adapter_cls = self.get_model_adapter(model_name_or_path=model_name_or_path)
        if not model_adapter_cls:
            raise AmazonBedrockConfigurationError(
                f"This invocation layer doesn't support the model {model_name_or_path}."
            )
        self.model_adapter = model_adapter_cls(model_kwargs=model_input_kwargs, max_length=self.max_length)

    def _ensure_token_limit(self, prompt: Union[str, List[Dict[str, str]]]) -> Union[str, List[Dict[str, str]]]:
        # the prompt for this model will be of the type str
        if isinstance(prompt, List):
            raise ValueError(
                "The Amazon Bedrock invocation layer only supports a string as a prompt, "
                "while currently, the prompt is a list of dictionaries."
            )

        resize_info = self.prompt_handler(prompt)
        if resize_info["prompt_length"] != resize_info["new_prompt_length"]:
            logger.warning(
                "The prompt was truncated from %s tokens to %s tokens so that the prompt length and "
                "the answer length (%s tokens) fit within the model's max token limit (%s tokens). "
                "Shorten the prompt or it will be cut off.",
                resize_info["prompt_length"],
                max(0, resize_info["model_max_length"] - resize_info["max_length"]),  # type: ignore
                resize_info["max_length"],
                resize_info["model_max_length"],
            )
        return str(resize_info["resized_prompt"])

    @classmethod
    def supports(cls, model_name_or_path, **kwargs):
        model_supported = cls.get_model_adapter(model_name_or_path) is not None
        if not model_supported or not cls.aws_configured(**kwargs):
            return False

        try:
            session = cls.get_aws_session(**kwargs)
            bedrock = session.client("bedrock")
            foundation_models_response = bedrock.list_foundation_models(byOutputModality="TEXT")
            available_model_ids = [entry["modelId"] for entry in foundation_models_response.get("modelSummaries", [])]
            model_ids_supporting_streaming = [
                entry["modelId"]
                for entry in foundation_models_response.get("modelSummaries", [])
                if entry.get("responseStreamingSupported", False)
            ]
        except AWSConfigurationError as exception:
            raise AmazonBedrockConfigurationError(message=exception.message) from exception
        except Exception as exception:
            raise AmazonBedrockConfigurationError(
                "Could not connect to Amazon Bedrock. Make sure the AWS environment is configured correctly. "
                "See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/quickstart.html#configuration"
            ) from exception

        model_available = model_name_or_path in available_model_ids
        if not model_available:
            raise AmazonBedrockConfigurationError(
                f"The model {model_name_or_path} is not available in Amazon Bedrock. "
                f"Make sure the model you want to use is available in the configured AWS region and you have access."
            )

        stream: bool = kwargs.get("stream", False)
        model_supports_streaming = model_name_or_path in model_ids_supporting_streaming
        if stream and not model_supports_streaming:
            raise AmazonBedrockConfigurationError(
                f"The model {model_name_or_path} doesn't support streaming. Remove the `stream` parameter."
            )

        return model_supported

    def invoke(self, *args, **kwargs):
        kwargs = kwargs.copy()
        prompt: str = kwargs.pop("prompt", None)
        stream: bool = kwargs.get("stream", self.model_adapter.model_kwargs.get("stream", False))

        if not prompt or not isinstance(prompt, (str, list)):
            raise ValueError(
                f"The model {self.model_name_or_path} requires a valid prompt, but currently, it has no prompt. "
                f"Make sure to provide a prompt in the format that the model expects."
            )

        body = self.model_adapter.prepare_body(prompt=prompt, **kwargs)
        try:
            if stream:
                response = self.client.invoke_model_with_response_stream(
                    body=json.dumps(body),
                    modelId=self.model_name_or_path,
                    accept="application/json",
                    contentType="application/json",
                )
                response_stream = response["body"]
                handler: TokenStreamingHandler = kwargs.get(
                    "stream_handler",
                    self.model_adapter.model_kwargs.get("stream_handler", DefaultTokenStreamingHandler()),
                )
                responses = self.model_adapter.get_stream_responses(stream=response_stream, stream_handler=handler)
            else:
                response = self.client.invoke_model(
                    body=json.dumps(body),
                    modelId=self.model_name_or_path,
                    accept="application/json",
                    contentType="application/json",
                )
                response_body = json.loads(response.get("body").read().decode("utf-8"))
                responses = self.model_adapter.get_responses(response_body=response_body)
        except ClientError as exception:
            raise AmazonBedrockInferenceError(
                f"Could not connect to Amazon Bedrock model {self.model_name_or_path}. "
                "Make sure your AWS environment is configured correctly, "
                "the model is available in the configured AWS region, and you have access."
            ) from exception

        return responses

    @classmethod
    def get_model_adapter(cls, model_name_or_path: str) -> Optional[Type[BedrockModelAdapter]]:
        for pattern, adapter in cls.SUPPORTED_MODEL_PATTERNS.items():
            if re.fullmatch(pattern, model_name_or_path):
                return adapter
        return None
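
End to end, the layer picks an adapter by regex, builds the request body, and dispatches to `invoke_model` or `invoke_model_with_response_stream`. A hedged usage sketch; the model ID, region, and prompts are assumptions, not taken from this diff:

```python
from haystack.nodes.prompt.invocation_layer.amazon_bedrock import AmazonBedrockInvocationLayer

layer = AmazonBedrockInvocationLayer(
    model_name_or_path="anthropic.claude-v2",  # matched by r"anthropic.claude.*"
    aws_region_name="us-east-1",  # assumed region
    max_length=200,
)

# Non-streaming: invoke_model is called and the adapter parses the JSON body.
responses = layer.invoke(prompt="Summarize Amazon Bedrock in one sentence.")
print(responses[0])

# Streaming: each decoded chunk is fed through a TokenStreamingHandler
# (DefaultTokenStreamingHandler prints tokens as they arrive).
streamed = layer.invoke(prompt="Tell me a short story.", stream=True)
```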


@@ -0,0 +1,79 @@
import logging
from abc import ABC
from typing import Optional

from haystack.errors import AWSConfigurationError
from haystack.lazy_imports import LazyImport
from haystack.nodes.prompt.invocation_layer import PromptModelInvocationLayer

logger = logging.getLogger(__name__)

with LazyImport(message="Run 'pip install farm-haystack[aws]'") as boto3_import:
    import boto3
    from botocore.exceptions import BotoCoreError

AWS_CONFIGURATION_KEYS = [
    "aws_access_key_id",
    "aws_secret_access_key",
    "aws_session_token",
    "aws_region_name",
    "aws_profile_name",
]


class AWSBaseInvocationLayer(PromptModelInvocationLayer, ABC):
    """
    Base class for AWS-based invocation layers.
    """

    @classmethod
    def aws_configured(cls, **kwargs) -> bool:
        """
        Checks whether AWS configuration is provided.
        :param kwargs: The kwargs passed down to the invocation layer.
        :return: True if AWS configuration is provided, False otherwise.
        """
        aws_config_provided = any(key in kwargs for key in AWS_CONFIGURATION_KEYS)
        return aws_config_provided

    @classmethod
    def get_aws_session(
        cls,
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
        aws_session_token: Optional[str] = None,
        aws_region_name: Optional[str] = None,
        aws_profile_name: Optional[str] = None,
        **kwargs,
    ):
        """
        Creates an AWS Session with the given parameters.
        Checks if the provided AWS credentials are valid and can be used to connect to AWS.

        :param aws_access_key_id: AWS access key ID.
        :param aws_secret_access_key: AWS secret access key.
        :param aws_session_token: AWS session token.
        :param aws_region_name: AWS region name.
        :param aws_profile_name: AWS profile name.
        :param kwargs: The kwargs passed down to the service client. Supported kwargs depend on the model chosen.
            See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html.
        :raises AWSConfigurationError: If the provided AWS credentials are invalid.
        :return: The created AWS session.
        """
        boto3_import.check()
        try:
            return boto3.Session(
                aws_access_key_id=aws_access_key_id,
                aws_secret_access_key=aws_secret_access_key,
                aws_session_token=aws_session_token,
                region_name=aws_region_name,
                profile_name=aws_profile_name,
            )
        except BotoCoreError as e:
            provided_aws_config = {k: v for k, v in kwargs.items() if k in AWS_CONFIGURATION_KEYS}
            raise AWSConfigurationError(
                f"Failed to initialize the session with provided AWS credentials {provided_aws_config}"
            ) from e
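
Because `aws_configured` and `get_aws_session` are classmethods on this shared base, both the Bedrock and SageMaker layers reuse them. A brief illustrative sketch (the profile name is an assumption):

```python
from haystack.nodes.prompt.invocation_layer.aws_base import AWSBaseInvocationLayer

# Check whether any AWS configuration kwargs were supplied, then build a session.
if AWSBaseInvocationLayer.aws_configured(aws_profile_name="my-profile"):
    session = AWSBaseInvocationLayer.get_aws_session(aws_profile_name="my-profile")
    bedrock_runtime = session.client("bedrock-runtime")
```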


@@ -1,12 +1,12 @@
import json
import logging
from abc import abstractmethod, ABC
from typing import Optional, Dict, Union, List, Any
from typing import Dict, Union, List, Any

from haystack.errors import SageMakerConfigurationError
from haystack.errors import AWSConfigurationError, SageMakerConfigurationError
from haystack.lazy_imports import LazyImport
from haystack.nodes.prompt.invocation_layer import PromptModelInvocationLayer
from haystack.nodes.prompt.invocation_layer.aws_base import AWSBaseInvocationLayer
from haystack.nodes.prompt.invocation_layer.handlers import DefaultPromptHandler

logger = logging.getLogger(__name__)
@@ -14,10 +14,10 @@ logger = logging.getLogger(__name__)

with LazyImport(message="Run 'pip install farm-haystack[aws]'") as boto3_import:
    import boto3
    from botocore.exceptions import ClientError, BotoCoreError
    from botocore.exceptions import ClientError


class SageMakerBaseInvocationLayer(PromptModelInvocationLayer, ABC):
class SageMakerBaseInvocationLayer(AWSBaseInvocationLayer, ABC):
    """
    Base class for SageMaker based invocation layers.
    """
@@ -70,18 +70,12 @@ class SageMakerBaseInvocationLayer(PromptModelInvocationLayer, ABC):
        :param model_name_or_path: The model_name_or_path to check.
        """
        aws_configuration_keys = [
            "aws_access_key_id",
            "aws_secret_access_key",
            "aws_session_token",
            "aws_region_name",
            "aws_profile_name",
        ]
        aws_config_provided = any(key in kwargs for key in aws_configuration_keys)
        if aws_config_provided:
            boto3_import.check()
        if cls.aws_configured(**kwargs):
            # attempt to create a session with the provided credentials
            session = cls.check_aws_connect(aws_configuration_keys, kwargs)
            try:
                session = cls.get_aws_session(**kwargs)
            except AWSConfigurationError as e:
                raise SageMakerConfigurationError(message=e.message) from e
            # is endpoint in service?
            cls.check_endpoint_in_service(session, model_name_or_path)
@@ -91,24 +85,6 @@ class SageMakerBaseInvocationLayer(PromptModelInvocationLayer, ABC):
            return supported
        return False

    @classmethod
    def check_aws_connect(cls, aws_configuration_keys: List[str], kwargs):
        """
        Checks if the provided AWS credentials are valid and can be used to connect to SageMaker.
        :param aws_configuration_keys: The AWS configuration keys to check.
        :param kwargs: The kwargs passed down to the SageMakerClient.
        :return: The boto3 session.
        """
        boto3_import.check()
        try:
            session = cls.create_session(**kwargs)
        except BotoCoreError as e:
            provided_aws_config = {k: v for k, v in kwargs.items() if k in aws_configuration_keys}
            raise SageMakerConfigurationError(
                f"Failed to initialize the session or client with provided AWS credentials {provided_aws_config}"
            ) from e
        return session

    @classmethod
    def check_endpoint_in_service(cls, session: "boto3.Session", endpoint: str):
        """
@@ -176,33 +152,3 @@ class SageMakerBaseInvocationLayer(PromptModelInvocationLayer, ABC):
        if client:
            client.close()
        return True

    @classmethod
    def create_session(
        cls,
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
        aws_session_token: Optional[str] = None,
        aws_region_name: Optional[str] = None,
        aws_profile_name: Optional[str] = None,
        **kwargs,
    ):
        """
        Creates an AWS Session with the given parameters.
        :param aws_access_key_id: AWS access key ID.
        :param aws_secret_access_key: AWS secret access key.
        :param aws_session_token: AWS session token.
        :param aws_region_name: AWS region name.
        :param aws_profile_name: AWS profile name.
        :raise NoCredentialsError: If the AWS credentials are not provided or invalid.
        :return: The created AWS Session.
        """
        boto3_import.check()
        return boto3.Session(
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            aws_session_token=aws_session_token,
            region_name=aws_region_name,
            profile_name=aws_profile_name,
        )


@@ -97,7 +97,7 @@ class SageMakerHFInferenceInvocationLayer(SageMakerBaseInvocationLayer):
        """
        super().__init__(model_name_or_path, max_length=max_length, **kwargs)
        try:
            session = SageMakerHFInferenceInvocationLayer.create_session(
            session = self.get_aws_session(
                aws_access_key_id=aws_access_key_id,
                aws_secret_access_key=aws_secret_access_key,
                aws_session_token=aws_session_token,


@@ -88,7 +88,7 @@ class SageMakerHFTextGenerationInvocationLayer(SageMakerBaseInvocationLayer):
        """
        super().__init__(model_name_or_path, max_length=max_length, **kwargs)
        try:
            session = SageMakerHFTextGenerationInvocationLayer.create_session(
            session = self.get_aws_session(
                aws_access_key_id=aws_access_key_id,
                aws_secret_access_key=aws_secret_access_key,
                aws_session_token=aws_session_token,


@@ -4,15 +4,16 @@ from typing import Optional, Dict, List, Any, Union
import requests

from haystack.errors import SageMakerModelNotReadyError, SageMakerInferenceError, SageMakerConfigurationError
from haystack.lazy_imports import LazyImport
from haystack.errors import (
    AWSConfigurationError,
    SageMakerModelNotReadyError,
    SageMakerInferenceError,
    SageMakerConfigurationError,
)
from haystack.nodes.prompt.invocation_layer.sagemaker_base import SageMakerBaseInvocationLayer

logger = logging.getLogger(__name__)

with LazyImport(message="Run 'pip install farm-haystack[aws]'") as boto3_import:
    pass


class SageMakerMetaInvocationLayer(SageMakerBaseInvocationLayer):
    """
@@ -139,7 +140,7 @@ class SageMakerMetaInvocationLayer(SageMakerBaseInvocationLayer):
        kwargs.setdefault("model_max_length", 4096)
        super().__init__(model_name_or_path, max_length=max_length, **kwargs)
        try:
            session = SageMakerMetaInvocationLayer.create_session(
            session = self.get_aws_session(
                aws_access_key_id=aws_access_key_id,
                aws_secret_access_key=aws_secret_access_key,
                aws_session_token=aws_session_token,
@@ -294,21 +295,15 @@ class SageMakerMetaInvocationLayer(SageMakerBaseInvocationLayer):
        :param model_name_or_path: The model_name_or_path to check.
        """
        aws_configuration_keys = [
            "aws_access_key_id",
            "aws_secret_access_key",
            "aws_session_token",
            "aws_region_name",
            "aws_profile_name",
        ]
        aws_config_provided = any(key in kwargs for key in aws_configuration_keys)
        accept_eula = False
        if "aws_custom_attributes" in kwargs and isinstance(kwargs["aws_custom_attributes"], dict):
            accept_eula = kwargs["aws_custom_attributes"].get("accept_eula", False)
        if aws_config_provided and accept_eula:
            boto3_import.check()
        if cls.aws_configured(**kwargs) and accept_eula:
            # attempt to create a session with the provided credentials
            session = cls.check_aws_connect(aws_configuration_keys, kwargs)
            try:
                session = cls.get_aws_session(**kwargs)
            except AWSConfigurationError as e:
                raise SageMakerConfigurationError(message=e.message) from e
            # is endpoint in service?
            cls.check_endpoint_in_service(session, model_name_or_path)


@@ -152,12 +152,8 @@ docstores-gpu = [
    "farm-haystack[elasticsearch,faiss-gpu,weaviate,pinecone,opensearch]",
]
aws = [
    "boto3",
    # Constrain botocore to avoid taking too much time to resolve the dependency tree.
    # boto3 used to constrain it at this version more than a year ago. To avoid breaking
    # people using old versions, we use a similar constraint without an upper bound.
    # https://github.com/boto/boto3/blob/dae73bef223abbedfa7317a783070831febc0c90/setup.py#L16
    "botocore>=1.27",
    # first version to support Amazon Bedrock
    "boto3>=1.28.57",
]
crawler = [
    "selenium>=4.11.0"
crawler = [
"selenium>=4.11.0"


@@ -0,0 +1,16 @@
---
prelude: >
  Haystack now supports Amazon Bedrock models, including all existing and previously announced
  models, like Llama-2-70b-chat. To use these models, simply pass the model ID in the
  model_name_or_path parameter, like you do for any other model. For details, see the
  [Amazon Bedrock Documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html).

  For example, the following code loads the Llama 2 Chat 13B model:

  ```python
  from haystack.nodes import PromptNode

  prompt_node = PromptNode(model_name_or_path="meta.llama2-13b-chat-v1")
  ```
features:
  - |
    You can use Amazon Bedrock models in Haystack.
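
Building on the release-note example, a slightly fuller sketch that forwards Bedrock inference parameters through `model_kwargs`; the parameter values are illustrative assumptions:

```python
from haystack.nodes import PromptNode

prompt_node = PromptNode(
    model_name_or_path="meta.llama2-13b-chat-v1",
    model_kwargs={"temperature": 0.3, "max_gen_len": 256},  # assumed values
)
print(prompt_node("Explain retrieval-augmented generation in one paragraph."))
```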

(File diff suppressed because it is too large.)


@@ -15,5 +15,5 @@ def test_invocation_layer_order():
    assert HFInferenceEndpointInvocationLayer in invocation_layers
    index_hf = invocation_layers.index(HFLocalInvocationLayer) + 1
    index_hf_inference = invocation_layers.index(HFInferenceEndpointInvocationLayer) + 1
    assert index_hf > len(invocation_layers) / 2
    assert index_hf_inference > len(invocation_layers) / 2
    assert index_hf >= 7
    assert index_hf_inference >= 7


@@ -223,7 +223,7 @@ def test_supports_for_valid_aws_configuration():
    # Patch the class method to return the mock session
    with patch(
        "haystack.nodes.prompt.invocation_layer.sagemaker_base.SageMakerBaseInvocationLayer.create_session",
        "haystack.nodes.prompt.invocation_layer.aws_base.AWSBaseInvocationLayer.get_aws_session",
        return_value=mock_session,
    ):
        supported = SageMakerHFInferenceInvocationLayer.supports(
@@ -245,12 +245,10 @@ def test_supports_not_on_invalid_aws_profile_name():
    with patch("boto3.Session") as mock_boto3_session:
        mock_boto3_session.side_effect = BotoCoreError()
        with pytest.raises(SageMakerConfigurationError) as exc_info:
            supported = SageMakerHFInferenceInvocationLayer.supports(
        with pytest.raises(SageMakerConfigurationError, match="Failed to initialize the session"):
            SageMakerHFInferenceInvocationLayer.supports(
                model_name_or_path="some_fake_model", aws_profile_name="some_fake_profile"
            )
        assert "Failed to initialize the session" in exc_info.value
        assert not supported


@pytest.mark.skipif(


@@ -81,9 +81,8 @@ def test_invoke_with_no_kwargs(mock_auto_tokenizer, mock_boto3_session):
    Test that invoke raises an error if no prompt is provided
    """
    layer = SageMakerHFTextGenerationInvocationLayer(model_name_or_path="some_fake_model")
    with pytest.raises(ValueError) as e:
    with pytest.raises(ValueError, match="No prompt provided."):
        layer.invoke()
    assert e.match("No prompt provided.")


@pytest.mark.unit
@@ -223,7 +222,7 @@ def test_supports_for_valid_aws_configuration():
    # Patch the class method to return the mock session
    with patch(
        "haystack.nodes.prompt.invocation_layer.sagemaker_base.SageMakerBaseInvocationLayer.create_session",
        "haystack.nodes.prompt.invocation_layer.aws_base.AWSBaseInvocationLayer.get_aws_session",
        return_value=mock_session,
    ):
        supported = SageMakerHFTextGenerationInvocationLayer.supports(
@@ -245,12 +244,10 @@ def test_supports_not_on_invalid_aws_profile_name():
    with patch("boto3.Session") as mock_boto3_session:
        mock_boto3_session.side_effect = BotoCoreError()
        with pytest.raises(SageMakerConfigurationError) as exc_info:
            supported = SageMakerHFTextGenerationInvocationLayer.supports(
        with pytest.raises(SageMakerConfigurationError, match="Failed to initialize the session"):
            SageMakerHFTextGenerationInvocationLayer.supports(
                model_name_or_path="some_fake_model", aws_profile_name="some_fake_profile"
            )
        assert "Failed to initialize the session" in exc_info.value
        assert not supported


@pytest.mark.skipif(


@@ -97,9 +97,8 @@ def test_invoke_with_no_kwargs(mock_auto_tokenizer, mock_boto3_session):
    Test invoke raises an error if no prompt is provided
    """
    layer = SageMakerMetaInvocationLayer(model_name_or_path="some_fake_model")
    with pytest.raises(ValueError) as e:
    with pytest.raises(ValueError, match="No valid prompt provided."):
        layer.invoke()
    assert e.match("No prompt provided.")


@pytest.mark.unit
@@ -239,7 +238,7 @@ def test_supports_for_valid_aws_configuration():
    # Patch the class method to return the mock session
    with patch(
        "haystack.nodes.prompt.invocation_layer.sagemaker_base.SageMakerBaseInvocationLayer.create_session",
        "haystack.nodes.prompt.invocation_layer.aws_base.AWSBaseInvocationLayer.get_aws_session",
        return_value=mock_session,
    ):
        supported = SageMakerMetaInvocationLayer.supports(
@@ -263,14 +262,12 @@ def test_supports_not_on_invalid_aws_profile_name():
    with patch("boto3.Session") as mock_boto3_session:
        mock_boto3_session.side_effect = BotoCoreError()
        with pytest.raises(SageMakerConfigurationError) as exc_info:
            supported = SageMakerMetaInvocationLayer.supports(
        with pytest.raises(SageMakerConfigurationError, match="Failed to initialize the session"):
            SageMakerMetaInvocationLayer.supports(
                model_name_or_path="some_fake_model",
                aws_profile_name="some_fake_profile",
                aws_custom_attributes={"accept_eula": True},
            )
        assert "Failed to initialize the session" in exc_info.value
        assert not supported


@pytest.mark.unit
@@ -287,7 +284,7 @@ def test_supports_not_on_missing_eula():
    # Patch the class method to return the mock session
    with patch(
        "haystack.nodes.prompt.invocation_layer.sagemaker_base.SageMakerBaseInvocationLayer.create_session",
        "haystack.nodes.prompt.invocation_layer.aws_base.AWSBaseInvocationLayer.get_aws_session",
        return_value=mock_session,
    ):
        supported = SageMakerMetaInvocationLayer.supports(
@@ -311,7 +308,7 @@ def test_supports_not_on_eula_not_accepted():
    # Patch the class method to return the mock session
    with patch(
        "haystack.nodes.prompt.invocation_layer.sagemaker_base.SageMakerBaseInvocationLayer.create_session",
        "haystack.nodes.prompt.invocation_layer.aws_base.AWSBaseInvocationLayer.get_aws_session",
        return_value=mock_session,
    ):
        supported = SageMakerMetaInvocationLayer.supports(