Mirror of https://github.com/deepset-ai/haystack.git (synced 2026-01-08 04:56:45 +00:00)
feat: add Amazon Bedrock support (#6226)
* Add Bedrock
* Update supported models for Bedrock
* Fix supports and add extract response in Bedrock
* Fix errors imports
* Improve and refactor supports
* Fix install
* Fix mypy
* Fix pylint
* Fix existing tests
* Added Anthropic Bedrock
* Fix tests
* Fix sagemaker tests
* Add default prompt handler, constructor and supports tests
* More tests
* Invoke refactoring
* Refactor model_kwargs
* Fix mypy
* Lstrip responses
* Add streaming support
* Bump boto3 version
* Add class docstrings, better exception names
* Fix layer name
* Add tests for anthropic and cohere model adapters
* Update cohere params
* Update ai21 args and add tests
* Support cohere command light model
* Add titan tests
* Better class names
* Support meta llama 2 model
* Fix streaming support
* More future-proof model adapter selection
* Fix import
* Fix mypy
* Fix pylint for preview
* Add tests for streaming
* Add release notes
* Apply suggestions from code review
* Fix format
* Fix tests after msg changes
* Fix streaming for cohere

Co-authored-by: tstadel <60758086+tstadel@users.noreply.github.com>
Co-authored-by: tstadel <thomas.stadelmann@deepset.ai>
Co-authored-by: Agnieszka Marzec <97166305+agnieszka-m@users.noreply.github.com>
This commit is contained in:
parent 08ec492039
commit f998bf4a4f
.github/workflows/linting_preview.yml (vendored): 5 changes
@@ -63,10 +63,7 @@ jobs:
        uses: tj-actions/changed-files@v40
        with:
          files: |
            **/*.py
          files_ignore: |
            test/**
            rest_api/test/**
            haystack/preview/**/*.py

      - uses: actions/setup-python@v4
        with:
haystack/errors.py
@@ -202,6 +202,27 @@ class HuggingFaceInferenceUnauthorizedError(HuggingFaceInferenceError):
     """Exception for issues that occur in the HuggingFace inference node due to unauthorized access"""


+class AWSConfigurationError(NodeError):
+    """Exception raised when AWS is not configured correctly"""
+
+    def __init__(self, message: Optional[str] = None, send_message_in_event: bool = False):
+        super().__init__(message=message, send_message_in_event=send_message_in_event)
+
+
+class AmazonBedrockConfigurationError(NodeError):
+    """Exception raised when the Amazon Bedrock node is not configured correctly"""
+
+    def __init__(self, message: Optional[str] = None, send_message_in_event: bool = False):
+        super().__init__(message=message, send_message_in_event=send_message_in_event)
+
+
+class AmazonBedrockInferenceError(NodeError):
+    """Exception for issues that occur in the Bedrock inference node"""
+
+    def __init__(self, message: Optional[str] = None, send_message_in_event: bool = False):
+        super().__init__(message=message, send_message_in_event=send_message_in_event)
+
+
 class SageMakerInferenceError(NodeError):
     """Exception for issues that occur in the SageMaker inference node"""
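The two new configuration errors separate setup failures from runtime inference failures. A minimal sketch of catching them in user code (the model ID is only an example):

```python
from haystack.errors import AmazonBedrockConfigurationError, AmazonBedrockInferenceError
from haystack.nodes.prompt.invocation_layer.amazon_bedrock import AmazonBedrockInvocationLayer

try:
    layer = AmazonBedrockInvocationLayer(model_name_or_path="anthropic.claude-v2")
    responses = layer.invoke(prompt="Hello")
except AmazonBedrockConfigurationError as err:
    # AWS session or model setup problem: bad credentials, wrong region, unsupported model
    print(f"Configuration problem: {err}")
except AmazonBedrockInferenceError as err:
    # The call to the Bedrock runtime itself failed
    print(f"Inference problem: {err}")
```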
haystack/nodes/prompt/invocation_layer/__init__.py
@@ -8,6 +8,7 @@ from haystack.nodes.prompt.invocation_layer.anthropic_claude import AnthropicCla
 from haystack.nodes.prompt.invocation_layer.cohere import CohereInvocationLayer
 from haystack.nodes.prompt.invocation_layer.hugging_face import HFLocalInvocationLayer
 from haystack.nodes.prompt.invocation_layer.hugging_face_inference import HFInferenceEndpointInvocationLayer
+from haystack.nodes.prompt.invocation_layer.amazon_bedrock import AmazonBedrockInvocationLayer
 from haystack.nodes.prompt.invocation_layer.sagemaker_meta import SageMakerMetaInvocationLayer
 from haystack.nodes.prompt.invocation_layer.sagemaker_hf_infer import SageMakerHFInferenceInvocationLayer
 from haystack.nodes.prompt.invocation_layer.sagemaker_hf_text_gen import SageMakerHFTextGenerationInvocationLayer
haystack/nodes/prompt/invocation_layer/amazon_bedrock.py: 372 additions (new file)
@@ -0,0 +1,372 @@
from abc import ABC, abstractmethod
import json
import logging
import re
from typing import Any, Optional, Dict, Type, Union, List

from haystack.errors import AWSConfigurationError, AmazonBedrockConfigurationError, AmazonBedrockInferenceError
from haystack.lazy_imports import LazyImport
from haystack.nodes.prompt.invocation_layer.aws_base import AWSBaseInvocationLayer
from haystack.nodes.prompt.invocation_layer.handlers import (
    DefaultPromptHandler,
    DefaultTokenStreamingHandler,
    TokenStreamingHandler,
)

logger = logging.getLogger(__name__)


with LazyImport(message="Run 'pip install farm-haystack[aws]'") as boto3_import:
    from botocore.exceptions import ClientError
class BedrockModelAdapter(ABC):
    """
    Base class for Amazon Bedrock model adapters.
    """

    def __init__(self, model_kwargs: Dict[str, Any], max_length: Optional[int]) -> None:
        self.model_kwargs = model_kwargs
        self.max_length = max_length

    @abstractmethod
    def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]:
        """Prepares the body for the Amazon Bedrock request."""

    def get_responses(self, response_body: Dict[str, Any]) -> List[str]:
        """Extracts the responses from the Amazon Bedrock response."""
        completions = self._extract_completions_from_response(response_body)
        responses = [completion.lstrip() for completion in completions]
        return responses

    def get_stream_responses(self, stream, stream_handler: TokenStreamingHandler) -> List[str]:
        """Extracts the responses from an Amazon Bedrock streaming response, one token at a time."""
        tokens: List[str] = []
        for event in stream:
            chunk = event.get("chunk")
            if chunk:
                decoded_chunk = json.loads(chunk["bytes"].decode("utf-8"))
                token = self._extract_token_from_stream(decoded_chunk)
                tokens.append(stream_handler(token, event_data=decoded_chunk))
        responses = ["".join(tokens).lstrip()]
        return responses

    def _get_params(self, inference_kwargs: Dict[str, Any], default_params: Dict[str, Any]) -> Dict[str, Any]:
        """
        Merges the default params with the inference kwargs and model kwargs.

        Includes param if it's in kwargs or its default is not None (i.e. it is actually defined).
        """
        kwargs = self.model_kwargs.copy()
        kwargs.update(inference_kwargs)
        return {
            param: kwargs.get(param, default)
            for param, default in default_params.items()
            if param in kwargs or default is not None
        }

    @abstractmethod
    def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]:
        """Extracts the responses from the Amazon Bedrock response."""

    @abstractmethod
    def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
        """Extracts the token from a streaming chunk."""
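To make the merge order above concrete, a standalone sketch of what `_get_params` computes (the parameter values are illustrative):

```python
# Same logic as BedrockModelAdapter._get_params: per-call kwargs win over model kwargs,
# and a default is emitted only when it is not None (i.e. actually defined).
model_kwargs = {"temperature": 0.5}
inference_kwargs = {"top_p": 0.9}
default_params = {"max_tokens_to_sample": 100, "temperature": None, "top_p": None, "top_k": None}

kwargs = {**model_kwargs, **inference_kwargs}
params = {p: kwargs.get(p, d) for p, d in default_params.items() if p in kwargs or d is not None}
assert params == {"max_tokens_to_sample": 100, "temperature": 0.5, "top_p": 0.9}
```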
class AnthropicClaudeAdapter(BedrockModelAdapter):
    """
    Model adapter for Anthropic's Claude model.
    """

    def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]:
        default_params = {
            "max_tokens_to_sample": self.max_length,
            "stop_sequences": ["\n\nHuman:"],
            "temperature": None,
            "top_p": None,
            "top_k": None,
        }
        params = self._get_params(inference_kwargs, default_params)

        body = {"prompt": f"\n\nHuman: {prompt}\n\nAssistant:", **params}
        return body

    def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]:
        return [response_body["completion"]]

    def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
        return chunk.get("completion", "")
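A quick standalone check of the request body this adapter builds (uses only the class above; the prompt text is illustrative):

```python
adapter = AnthropicClaudeAdapter(model_kwargs={}, max_length=100)
# Defaults with None values (temperature, top_p, top_k) are dropped unless the caller sets them.
assert adapter.prepare_body(prompt="Hello") == {
    "prompt": "\n\nHuman: Hello\n\nAssistant:",
    "max_tokens_to_sample": 100,
    "stop_sequences": ["\n\nHuman:"],
}
```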
class CohereCommandAdapter(BedrockModelAdapter):
    """
    Model adapter for Cohere's Command model.
    """

    def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]:
        default_params = {
            "max_tokens": self.max_length,
            "stop_sequences": None,
            "temperature": None,
            "p": None,
            "k": None,
            "return_likelihoods": None,
            "stream": None,
            "logit_bias": None,
            "num_generations": None,
            "truncate": None,
        }
        params = self._get_params(inference_kwargs, default_params)

        body = {"prompt": prompt, **params}
        return body

    def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]:
        responses = [generation["text"] for generation in response_body["generations"]]
        return responses

    def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
        return chunk.get("text", "")
class AI21LabsJurassic2Adapter(BedrockModelAdapter):
    """
    Model adapter for AI21 Labs' Jurassic 2 models.
    """

    def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]:
        default_params = {
            "maxTokens": self.max_length,
            "stopSequences": None,
            "temperature": None,
            "topP": None,
            "countPenalty": None,
            "presencePenalty": None,
            "frequencyPenalty": None,
            "numResults": None,
        }
        params = self._get_params(inference_kwargs, default_params)

        body = {"prompt": prompt, **params}
        return body

    def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]:
        responses = [completion["data"]["text"] for completion in response_body["completions"]]
        return responses

    def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
        raise NotImplementedError("Streaming is not supported for AI21 Jurassic 2 models.")
class AmazonTitanAdapter(BedrockModelAdapter):
    """
    Model adapter for Amazon's Titan models.
    """

    def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]:
        default_params = {"maxTokenCount": self.max_length, "stopSequences": None, "temperature": None, "topP": None}
        params = self._get_params(inference_kwargs, default_params)

        body = {"inputText": prompt, "textGenerationConfig": params}
        return body

    def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]:
        responses = [result["outputText"] for result in response_body["results"]]
        return responses

    def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
        return chunk.get("outputText", "")
class MetaLlama2ChatAdapter(BedrockModelAdapter):
    """
    Model adapter for Meta's Llama 2 Chat models.
    """

    def prepare_body(self, prompt: str, **inference_kwargs) -> Dict[str, Any]:
        default_params = {"max_gen_len": self.max_length, "temperature": None, "top_p": None}
        params = self._get_params(inference_kwargs, default_params)

        body = {"prompt": prompt, **params}
        return body

    def _extract_completions_from_response(self, response_body: Dict[str, Any]) -> List[str]:
        return [response_body["generation"]]

    def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str:
        return chunk.get("generation", "")
class AmazonBedrockInvocationLayer(AWSBaseInvocationLayer):
    """
    Invocation layer for Amazon Bedrock models.
    """

    SUPPORTED_MODEL_PATTERNS: Dict[str, Type[BedrockModelAdapter]] = {
        r"amazon.titan-text.*": AmazonTitanAdapter,
        r"ai21.j2.*": AI21LabsJurassic2Adapter,
        r"cohere.command.*": CohereCommandAdapter,
        r"anthropic.claude.*": AnthropicClaudeAdapter,
        r"meta.llama2.*": MetaLlama2ChatAdapter,
    }

    def __init__(
        self,
        model_name_or_path: str,
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
        aws_session_token: Optional[str] = None,
        aws_region_name: Optional[str] = None,
        aws_profile_name: Optional[str] = None,
        max_length: Optional[int] = 100,
        **kwargs,
    ):
        super().__init__(model_name_or_path, **kwargs)
        self.max_length = max_length

        try:
            session = self.get_aws_session(
                aws_access_key_id=aws_access_key_id,
                aws_secret_access_key=aws_secret_access_key,
                aws_session_token=aws_session_token,
                aws_region_name=aws_region_name,
                aws_profile_name=aws_profile_name,
            )
            self.client = session.client("bedrock-runtime")
        except Exception as exception:
            raise AmazonBedrockConfigurationError(
                "Could not connect to Amazon Bedrock. Make sure the AWS environment is configured correctly. "
                "See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/quickstart.html#configuration"
            ) from exception

        model_input_kwargs = kwargs
        # We pop the model_max_length as it is not sent to the model
        # but used to truncate the prompt if needed
        model_max_length = kwargs.get("model_max_length", 4096)

        # Truncate prompt if prompt tokens > model_max_length-max_length
        # (max_length is the length of the generated text)
        # It is hard to determine which tokenizer to use for the Bedrock model,
        # so we use the GPT2 tokenizer, which will likely provide a good token count approximation.
        self.prompt_handler = DefaultPromptHandler(
            model_name_or_path="gpt2", model_max_length=model_max_length, max_length=self.max_length or 100
        )

        model_adapter_cls = self.get_model_adapter(model_name_or_path=model_name_or_path)
        if not model_adapter_cls:
            raise AmazonBedrockConfigurationError(
                f"This invocation layer doesn't support the model {model_name_or_path}."
            )
        self.model_adapter = model_adapter_cls(model_kwargs=model_input_kwargs, max_length=self.max_length)
    def _ensure_token_limit(self, prompt: Union[str, List[Dict[str, str]]]) -> Union[str, List[Dict[str, str]]]:
        # the prompt for this model will be of the type str
        if isinstance(prompt, List):
            raise ValueError(
                "The Amazon Bedrock invocation layer only supports a string as a prompt, "
                "while currently, the prompt is a list."
            )

        resize_info = self.prompt_handler(prompt)
        if resize_info["prompt_length"] != resize_info["new_prompt_length"]:
            logger.warning(
                "The prompt was truncated from %s tokens to %s tokens so that the prompt length and "
                "the answer length (%s tokens) fit within the model's max token limit (%s tokens). "
                "Shorten the prompt or it will be cut off.",
                resize_info["prompt_length"],
                max(0, resize_info["model_max_length"] - resize_info["max_length"]),  # type: ignore
                resize_info["max_length"],
                resize_info["model_max_length"],
            )
        return str(resize_info["resized_prompt"])
    @classmethod
    def supports(cls, model_name_or_path, **kwargs):
        model_supported = cls.get_model_adapter(model_name_or_path) is not None
        if not model_supported or not cls.aws_configured(**kwargs):
            return False

        try:
            session = cls.get_aws_session(**kwargs)
            bedrock = session.client("bedrock")
            foundation_models_response = bedrock.list_foundation_models(byOutputModality="TEXT")
            available_model_ids = [entry["modelId"] for entry in foundation_models_response.get("modelSummaries", [])]
            model_ids_supporting_streaming = [
                entry["modelId"]
                for entry in foundation_models_response.get("modelSummaries", [])
                if entry.get("responseStreamingSupported", False)
            ]
        except AWSConfigurationError as exception:
            raise AmazonBedrockConfigurationError(message=exception.message) from exception
        except Exception as exception:
            raise AmazonBedrockConfigurationError(
                "Could not connect to Amazon Bedrock. Make sure the AWS environment is configured correctly. "
                "See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/quickstart.html#configuration"
            ) from exception

        model_available = model_name_or_path in available_model_ids
        if not model_available:
            raise AmazonBedrockConfigurationError(
                f"The model {model_name_or_path} is not available in Amazon Bedrock. "
                f"Make sure the model you want to use is available in the configured AWS region and you have access."
            )

        stream: bool = kwargs.get("stream", False)
        model_supports_streaming = model_name_or_path in model_ids_supporting_streaming
        if stream and not model_supports_streaming:
            raise AmazonBedrockConfigurationError(
                f"The model {model_name_or_path} doesn't support streaming. Remove the `stream` parameter."
            )

        return model_supported
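Note that `supports` is more than a pattern match: it calls `list_foundation_models` against the live service and raises if the model isn't actually available in the configured account and region. A sketch of a programmatic check (the profile name is a placeholder; any AWS kwarg activates the check):

```python
ok = AmazonBedrockInvocationLayer.supports(
    "anthropic.claude-v2",
    aws_profile_name="my-profile",
)
```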
    def invoke(self, *args, **kwargs):
        kwargs = kwargs.copy()
        prompt: str = kwargs.pop("prompt", None)
        stream: bool = kwargs.get("stream", self.model_adapter.model_kwargs.get("stream", False))

        if not prompt or not isinstance(prompt, (str, list)):
            raise ValueError(
                f"The model {self.model_name_or_path} requires a valid prompt, but currently, it has no prompt. "
                f"Make sure to provide a prompt in the format that the model expects."
            )

        body = self.model_adapter.prepare_body(prompt=prompt, **kwargs)
        try:
            if stream:
                response = self.client.invoke_model_with_response_stream(
                    body=json.dumps(body),
                    modelId=self.model_name_or_path,
                    accept="application/json",
                    contentType="application/json",
                )
                response_stream = response["body"]
                handler: TokenStreamingHandler = kwargs.get(
                    "stream_handler",
                    self.model_adapter.model_kwargs.get("stream_handler", DefaultTokenStreamingHandler()),
                )
                responses = self.model_adapter.get_stream_responses(stream=response_stream, stream_handler=handler)
            else:
                response = self.client.invoke_model(
                    body=json.dumps(body),
                    modelId=self.model_name_or_path,
                    accept="application/json",
                    contentType="application/json",
                )
                response_body = json.loads(response.get("body").read().decode("utf-8"))
                responses = self.model_adapter.get_responses(response_body=response_body)
        except ClientError as exception:
            raise AmazonBedrockInferenceError(
                f"Could not connect to Amazon Bedrock model {self.model_name_or_path}. "
                "Make sure your AWS environment is configured correctly, "
                "the model is available in the configured AWS region, and you have access."
            ) from exception

        return responses
    @classmethod
    def get_model_adapter(cls, model_name_or_path: str) -> Optional[Type[BedrockModelAdapter]]:
        for pattern, adapter in cls.SUPPORTED_MODEL_PATTERNS.items():
            if re.fullmatch(pattern, model_name_or_path):
                return adapter
        return None
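Putting the layer together, a minimal usage sketch (the model ID and region are example values; with no credential kwargs, boto3 falls back to its usual credential chain):

```python
from haystack.nodes.prompt.invocation_layer.amazon_bedrock import AmazonBedrockInvocationLayer

layer = AmazonBedrockInvocationLayer(
    model_name_or_path="anthropic.claude-v2",  # matched by r"anthropic.claude.*" -> AnthropicClaudeAdapter
    aws_region_name="us-east-1",
    max_length=200,
)
responses = layer.invoke(prompt="What is Amazon Bedrock?")

# Streaming works only for models whose Bedrock listing reports responseStreamingSupported:
streamed = layer.invoke(prompt="What is Amazon Bedrock?", stream=True)
```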
haystack/nodes/prompt/invocation_layer/aws_base.py: 79 additions (new file)
@@ -0,0 +1,79 @@
import logging
from abc import ABC
from typing import Optional


from haystack.errors import AWSConfigurationError
from haystack.lazy_imports import LazyImport
from haystack.nodes.prompt.invocation_layer import PromptModelInvocationLayer

logger = logging.getLogger(__name__)


with LazyImport(message="Run 'pip install farm-haystack[aws]'") as boto3_import:
    import boto3
    from botocore.exceptions import BotoCoreError


AWS_CONFIGURATION_KEYS = [
    "aws_access_key_id",
    "aws_secret_access_key",
    "aws_session_token",
    "aws_region_name",
    "aws_profile_name",
]
class AWSBaseInvocationLayer(PromptModelInvocationLayer, ABC):
    """
    Base class for AWS based invocation layers.
    """

    @classmethod
    def aws_configured(cls, **kwargs) -> bool:
        """
        Checks whether AWS configuration was provided, i.e. whether this invocation layer is active.
        :param kwargs: The kwargs passed down to the invocation layer.
        :return: True if the invocation layer is active, False otherwise.
        """
        aws_config_provided = any(key in kwargs for key in AWS_CONFIGURATION_KEYS)
        return aws_config_provided
    @classmethod
    def get_aws_session(
        cls,
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
        aws_session_token: Optional[str] = None,
        aws_region_name: Optional[str] = None,
        aws_profile_name: Optional[str] = None,
        **kwargs,
    ):
        """
        Creates an AWS Session with the given parameters.
        Checks if the provided AWS credentials are valid and can be used to connect to AWS.

        :param aws_access_key_id: AWS access key ID.
        :param aws_secret_access_key: AWS secret access key.
        :param aws_session_token: AWS session token.
        :param aws_region_name: AWS region name.
        :param aws_profile_name: AWS profile name.
        :param kwargs: The kwargs passed down to the service client. Supported kwargs depend on the model chosen.
            See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html.
        :raises AWSConfigurationError: If the provided AWS credentials are invalid.
        :return: The created AWS session.
        """
        boto3_import.check()
        try:
            return boto3.Session(
                aws_access_key_id=aws_access_key_id,
                aws_secret_access_key=aws_secret_access_key,
                aws_session_token=aws_session_token,
                region_name=aws_region_name,
                profile_name=aws_profile_name,
            )
        except BotoCoreError as e:
            provided_aws_config = {k: v for k, v in kwargs.items() if k in AWS_CONFIGURATION_KEYS}
            raise AWSConfigurationError(
                f"Failed to initialize the session with provided AWS credentials {provided_aws_config}"
            ) from e
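A short sketch of how the shared base class is meant to be used by subclasses (the profile name is a placeholder; both methods are classmethods on the abstract base):

```python
from haystack.nodes.prompt.invocation_layer.aws_base import AWSBaseInvocationLayer

kwargs = {"aws_profile_name": "my-profile", "aws_region_name": "us-east-1"}
if AWSBaseInvocationLayer.aws_configured(**kwargs):
    session = AWSBaseInvocationLayer.get_aws_session(**kwargs)
    client = session.client("bedrock-runtime")
```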
haystack/nodes/prompt/invocation_layer/sagemaker_base.py
@@ -1,12 +1,12 @@
 import json
 import logging
 from abc import abstractmethod, ABC
-from typing import Optional, Dict, Union, List, Any
+from typing import Dict, Union, List, Any


-from haystack.errors import SageMakerConfigurationError
+from haystack.errors import AWSConfigurationError, SageMakerConfigurationError
 from haystack.lazy_imports import LazyImport
 from haystack.nodes.prompt.invocation_layer import PromptModelInvocationLayer
+from haystack.nodes.prompt.invocation_layer.aws_base import AWSBaseInvocationLayer
 from haystack.nodes.prompt.invocation_layer.handlers import DefaultPromptHandler

 logger = logging.getLogger(__name__)
@@ -14,10 +14,10 @@ logger = logging.getLogger(__name__)

 with LazyImport(message="Run 'pip install farm-haystack[aws]'") as boto3_import:
     import boto3
-    from botocore.exceptions import ClientError, BotoCoreError
+    from botocore.exceptions import ClientError


-class SageMakerBaseInvocationLayer(PromptModelInvocationLayer, ABC):
+class SageMakerBaseInvocationLayer(AWSBaseInvocationLayer, ABC):
     """
     Base class for SageMaker based invocation layers.
     """
@@ -70,18 +70,12 @@ class SageMakerBaseInvocationLayer(PromptModelInvocationLayer, ABC):

         :param model_name_or_path: The model_name_or_path to check.
         """
-        aws_configuration_keys = [
-            "aws_access_key_id",
-            "aws_secret_access_key",
-            "aws_session_token",
-            "aws_region_name",
-            "aws_profile_name",
-        ]
-        aws_config_provided = any(key in kwargs for key in aws_configuration_keys)
-        if aws_config_provided:
-            boto3_import.check()
+        if cls.aws_configured(**kwargs):
             # attempt to create a session with the provided credentials
-            session = cls.check_aws_connect(aws_configuration_keys, kwargs)
+            try:
+                session = cls.get_aws_session(**kwargs)
+            except AWSConfigurationError as e:
+                raise SageMakerConfigurationError(message=e.message) from e
             # is endpoint in service?
             cls.check_endpoint_in_service(session, model_name_or_path)
@@ -91,24 +85,6 @@ class SageMakerBaseInvocationLayer(PromptModelInvocationLayer, ABC):
             return supported
         return False

-    @classmethod
-    def check_aws_connect(cls, aws_configuration_keys: List[str], kwargs):
-        """
-        Checks if the provided AWS credentials are valid and can be used to connect to SageMaker.
-        :param aws_configuration_keys: The AWS configuration keys to check.
-        :param kwargs: The kwargs passed down to the SageMakerClient.
-        :return: The boto3 session.
-        """
-        boto3_import.check()
-        try:
-            session = cls.create_session(**kwargs)
-        except BotoCoreError as e:
-            provided_aws_config = {k: v for k, v in kwargs.items() if k in aws_configuration_keys}
-            raise SageMakerConfigurationError(
-                f"Failed to initialize the session or client with provided AWS credentials {provided_aws_config}"
-            ) from e
-        return session
-
     @classmethod
     def check_endpoint_in_service(cls, session: "boto3.Session", endpoint: str):
         """
@@ -176,33 +152,3 @@ class SageMakerBaseInvocationLayer(PromptModelInvocationLayer, ABC):
         if client:
             client.close()
         return True
-
-    @classmethod
-    def create_session(
-        cls,
-        aws_access_key_id: Optional[str] = None,
-        aws_secret_access_key: Optional[str] = None,
-        aws_session_token: Optional[str] = None,
-        aws_region_name: Optional[str] = None,
-        aws_profile_name: Optional[str] = None,
-        **kwargs,
-    ):
-        """
-        Creates an AWS Session with the given parameters.
-
-        :param aws_access_key_id: AWS access key ID.
-        :param aws_secret_access_key: AWS secret access key.
-        :param aws_session_token: AWS session token.
-        :param aws_region_name: AWS region name.
-        :param aws_profile_name: AWS profile name.
-        :raise NoCredentialsError: If the AWS credentials are not provided or invalid.
-        :return: The created AWS Session.
-        """
-        boto3_import.check()
-        return boto3.Session(
-            aws_access_key_id=aws_access_key_id,
-            aws_secret_access_key=aws_secret_access_key,
-            aws_session_token=aws_session_token,
-            region_name=aws_region_name,
-            profile_name=aws_profile_name,
-        )
haystack/nodes/prompt/invocation_layer/sagemaker_hf_infer.py
@@ -97,7 +97,7 @@ class SageMakerHFInferenceInvocationLayer(SageMakerBaseInvocationLayer):
         """
         super().__init__(model_name_or_path, max_length=max_length, **kwargs)
         try:
-            session = SageMakerHFInferenceInvocationLayer.create_session(
+            session = self.get_aws_session(
                 aws_access_key_id=aws_access_key_id,
                 aws_secret_access_key=aws_secret_access_key,
                 aws_session_token=aws_session_token,
haystack/nodes/prompt/invocation_layer/sagemaker_hf_text_gen.py
@@ -88,7 +88,7 @@ class SageMakerHFTextGenerationInvocationLayer(SageMakerBaseInvocationLayer):
         """
         super().__init__(model_name_or_path, max_length=max_length, **kwargs)
         try:
-            session = SageMakerHFTextGenerationInvocationLayer.create_session(
+            session = self.get_aws_session(
                 aws_access_key_id=aws_access_key_id,
                 aws_secret_access_key=aws_secret_access_key,
                 aws_session_token=aws_session_token,
haystack/nodes/prompt/invocation_layer/sagemaker_meta.py
@@ -4,15 +4,16 @@ from typing import Optional, Dict, List, Any, Union

 import requests

-from haystack.errors import SageMakerModelNotReadyError, SageMakerInferenceError, SageMakerConfigurationError
 from haystack.lazy_imports import LazyImport
+from haystack.errors import (
+    AWSConfigurationError,
+    SageMakerModelNotReadyError,
+    SageMakerInferenceError,
+    SageMakerConfigurationError,
+)
 from haystack.nodes.prompt.invocation_layer.sagemaker_base import SageMakerBaseInvocationLayer

 logger = logging.getLogger(__name__)

 with LazyImport(message="Run 'pip install farm-haystack[aws]'") as boto3_import:
     pass


 class SageMakerMetaInvocationLayer(SageMakerBaseInvocationLayer):
     """
@@ -139,7 +140,7 @@ class SageMakerMetaInvocationLayer(SageMakerBaseInvocationLayer):
         kwargs.setdefault("model_max_length", 4096)
         super().__init__(model_name_or_path, max_length=max_length, **kwargs)
         try:
-            session = SageMakerMetaInvocationLayer.create_session(
+            session = self.get_aws_session(
                 aws_access_key_id=aws_access_key_id,
                 aws_secret_access_key=aws_secret_access_key,
                 aws_session_token=aws_session_token,
@@ -294,21 +295,15 @@ class SageMakerMetaInvocationLayer(SageMakerBaseInvocationLayer):

         :param model_name_or_path: The model_name_or_path to check.
         """
-        aws_configuration_keys = [
-            "aws_access_key_id",
-            "aws_secret_access_key",
-            "aws_session_token",
-            "aws_region_name",
-            "aws_profile_name",
-        ]
-        aws_config_provided = any(key in kwargs for key in aws_configuration_keys)
         accept_eula = False
         if "aws_custom_attributes" in kwargs and isinstance(kwargs["aws_custom_attributes"], dict):
             accept_eula = kwargs["aws_custom_attributes"].get("accept_eula", False)
-        if aws_config_provided and accept_eula:
-            boto3_import.check()
+        if cls.aws_configured(**kwargs) and accept_eula:
             # attempt to create a session with the provided credentials
-            session = cls.check_aws_connect(aws_configuration_keys, kwargs)
+            try:
+                session = cls.get_aws_session(**kwargs)
+            except AWSConfigurationError as e:
+                raise SageMakerConfigurationError(message=e.message) from e
             # is endpoint in service?
             cls.check_endpoint_in_service(session, model_name_or_path)
pyproject.toml
@@ -152,12 +152,8 @@ docstores-gpu = [
     "farm-haystack[elasticsearch,faiss-gpu,weaviate,pinecone,opensearch]",
 ]
 aws = [
-    "boto3",
-    # Constrain botocore to avoid taking too much time to resolve the dependency tree.
-    # boto3 used to constrain it at this version more than a year ago. To avoid breaking
-    # people using old versions we use a similar constraint without upper bound.
-    # https://github.com/boto/boto3/blob/dae73bef223abbedfa7317a783070831febc0c90/setup.py#L16
-    "botocore>=1.27",
+    # first version to support Amazon Bedrock
+    "boto3>=1.28.57",
 ]
 crawler = [
     "selenium>=4.11.0"
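With this extra in place, the AWS-backed invocation layers install via pip install 'farm-haystack[aws]', matching the hint raised by the LazyImport guards above.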
releasenotes/notes/bedrock-support-bce28e3078c85c12.yaml: 16 additions (new file)
@@ -0,0 +1,16 @@
---
prelude: >
  Haystack now supports Amazon Bedrock models, including all existing and previously announced
  models, like Llama-2-70b-chat. To use these models, simply pass the model ID in the
  model_name_or_path parameter, like you do for any other model. For details, see
  [Amazon Bedrock Documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html).

  For example, the following code loads the Llama 2 Chat 13B model:

  ```python
  from haystack.nodes import PromptNode

  prompt_node = PromptNode(model_name_or_path="meta.llama2-13b-chat-v1")
  ```
features:
  - |
    You can use Amazon Bedrock models in Haystack.
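AWS settings can ride along with the model kwargs; a sketch, assuming PromptNode's model_kwargs parameter is forwarded to the invocation layer (the profile and region values are placeholders):

```python
from haystack.nodes import PromptNode

prompt_node = PromptNode(
    model_name_or_path="meta.llama2-13b-chat-v1",
    model_kwargs={"aws_profile_name": "my-profile", "aws_region_name": "us-east-1"},
)
result = prompt_node("What is Amazon Bedrock?")
```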
test/prompt/invocation_layer/test_amazon_bedrock.py: 1033 additions (new file; diff suppressed because it is too large)
@@ -15,5 +15,5 @@ def test_invocation_layer_order():
     assert HFInferenceEndpointInvocationLayer in invocation_layers
     index_hf = invocation_layers.index(HFLocalInvocationLayer) + 1
     index_hf_inference = invocation_layers.index(HFInferenceEndpointInvocationLayer) + 1
-    assert index_hf > len(invocation_layers) / 2
-    assert index_hf_inference > len(invocation_layers) / 2
+    assert index_hf >= 7
+    assert index_hf_inference >= 7
@@ -223,7 +223,7 @@ def test_supports_for_valid_aws_configuration():

     # Patch the class method to return the mock session
     with patch(
-        "haystack.nodes.prompt.invocation_layer.sagemaker_base.SageMakerBaseInvocationLayer.create_session",
+        "haystack.nodes.prompt.invocation_layer.aws_base.AWSBaseInvocationLayer.get_aws_session",
         return_value=mock_session,
     ):
         supported = SageMakerHFInferenceInvocationLayer.supports(
@@ -245,12 +245,10 @@ def test_supports_not_on_invalid_aws_profile_name():

     with patch("boto3.Session") as mock_boto3_session:
         mock_boto3_session.side_effect = BotoCoreError()
-        with pytest.raises(SageMakerConfigurationError) as exc_info:
-            supported = SageMakerHFInferenceInvocationLayer.supports(
+        with pytest.raises(SageMakerConfigurationError, match="Failed to initialize the session"):
+            SageMakerHFInferenceInvocationLayer.supports(
                 model_name_or_path="some_fake_model", aws_profile_name="some_fake_profile"
             )
-        assert "Failed to initialize the session" in exc_info.value
-        assert not supported


 @pytest.mark.skipif(
@@ -81,9 +81,8 @@ def test_invoke_with_no_kwargs(mock_auto_tokenizer, mock_boto3_session):
     Test that invoke raises an error if no prompt is provided
     """
     layer = SageMakerHFTextGenerationInvocationLayer(model_name_or_path="some_fake_model")
-    with pytest.raises(ValueError) as e:
+    with pytest.raises(ValueError, match="No prompt provided."):
         layer.invoke()
-    assert e.match("No prompt provided.")


 @pytest.mark.unit
@@ -223,7 +222,7 @@ def test_supports_for_valid_aws_configuration():

     # Patch the class method to return the mock session
     with patch(
-        "haystack.nodes.prompt.invocation_layer.sagemaker_base.SageMakerBaseInvocationLayer.create_session",
+        "haystack.nodes.prompt.invocation_layer.aws_base.AWSBaseInvocationLayer.get_aws_session",
         return_value=mock_session,
     ):
         supported = SageMakerHFTextGenerationInvocationLayer.supports(
@@ -245,12 +244,10 @@ def test_supports_not_on_invalid_aws_profile_name():

     with patch("boto3.Session") as mock_boto3_session:
         mock_boto3_session.side_effect = BotoCoreError()
-        with pytest.raises(SageMakerConfigurationError) as exc_info:
-            supported = SageMakerHFTextGenerationInvocationLayer.supports(
+        with pytest.raises(SageMakerConfigurationError, match="Failed to initialize the session"):
+            SageMakerHFTextGenerationInvocationLayer.supports(
                 model_name_or_path="some_fake_model", aws_profile_name="some_fake_profile"
             )
-        assert "Failed to initialize the session" in exc_info.value
-        assert not supported


 @pytest.mark.skipif(
@@ -97,9 +97,8 @@ def test_invoke_with_no_kwargs(mock_auto_tokenizer, mock_boto3_session):
     Test invoke raises an error if no prompt is provided
     """
     layer = SageMakerMetaInvocationLayer(model_name_or_path="some_fake_model")
-    with pytest.raises(ValueError) as e:
+    with pytest.raises(ValueError, match="No valid prompt provided."):
         layer.invoke()
-    assert e.match("No prompt provided.")


 @pytest.mark.unit
@@ -239,7 +238,7 @@ def test_supports_for_valid_aws_configuration():

     # Patch the class method to return the mock session
     with patch(
-        "haystack.nodes.prompt.invocation_layer.sagemaker_base.SageMakerBaseInvocationLayer.create_session",
+        "haystack.nodes.prompt.invocation_layer.aws_base.AWSBaseInvocationLayer.get_aws_session",
         return_value=mock_session,
     ):
         supported = SageMakerMetaInvocationLayer.supports(
@@ -263,14 +262,12 @@ def test_supports_not_on_invalid_aws_profile_name():

     with patch("boto3.Session") as mock_boto3_session:
         mock_boto3_session.side_effect = BotoCoreError()
-        with pytest.raises(SageMakerConfigurationError) as exc_info:
-            supported = SageMakerMetaInvocationLayer.supports(
+        with pytest.raises(SageMakerConfigurationError, match="Failed to initialize the session"):
+            SageMakerMetaInvocationLayer.supports(
                 model_name_or_path="some_fake_model",
                 aws_profile_name="some_fake_profile",
                 aws_custom_attributes={"accept_eula": True},
             )
-        assert "Failed to initialize the session" in exc_info.value
-        assert not supported


 @pytest.mark.unit
@@ -287,7 +284,7 @@ def test_supports_not_on_missing_eula():

     # Patch the class method to return the mock session
     with patch(
-        "haystack.nodes.prompt.invocation_layer.sagemaker_base.SageMakerBaseInvocationLayer.create_session",
+        "haystack.nodes.prompt.invocation_layer.aws_base.AWSBaseInvocationLayer.get_aws_session",
         return_value=mock_session,
     ):
         supported = SageMakerMetaInvocationLayer.supports(
@@ -311,7 +308,7 @@ def test_supports_not_on_eula_not_accepted():

     # Patch the class method to return the mock session
     with patch(
-        "haystack.nodes.prompt.invocation_layer.sagemaker_base.SageMakerBaseInvocationLayer.create_session",
+        "haystack.nodes.prompt.invocation_layer.aws_base.AWSBaseInvocationLayer.get_aws_session",
         return_value=mock_session,
     ):
         supported = SageMakerMetaInvocationLayer.supports(