"""Amazon Bedrock LLM completion and embedding bindings for LightRAG.

Source path: LightRAG/lightrag/llm/bedrock.py
"""
import copy
import json
import logging
import os
import sys

import pipmaster as pm  # Pipmaster for dynamic library install

if not pm.is_installed("aioboto3"):
    pm.install("aioboto3")

import aioboto3
import numpy as np
from tenacity import (
    retry,
    stop_after_attempt,
    wait_exponential,
    retry_if_exception_type,
)

if sys.version_info < (3, 9):
    from typing import AsyncIterator
else:
    from collections.abc import AsyncIterator
from typing import Union
class BedrockError(Exception):
    """Raised for any failure while talking to Amazon Bedrock.

    Used as a single catch-all wrapper for API-call and streaming errors so
    the tenacity retry policy on the completion function can target one type.
    """
def _set_env_if_present(key: str, value):
"""Set environment variable only if a non-empty value is provided."""
if value is not None and value != "":
os.environ[key] = value
@retry(
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, max=60),
    retry=retry_if_exception_type(BedrockError),
)
async def bedrock_complete_if_cache(
    model,
    prompt,
    system_prompt=None,
    history_messages=None,
    aws_access_key_id=None,
    aws_secret_access_key=None,
    aws_session_token=None,
    **kwargs,
) -> Union[str, AsyncIterator[str]]:
    """Complete *prompt* with an Amazon Bedrock model via the Converse API.

    Args:
        model: Bedrock model id, passed through as ``modelId``.
        prompt: User prompt appended as the final conversation message.
        system_prompt: Optional system instruction for the model.
        history_messages: Prior chat turns as ``{"role": ..., "content": str}``
            dicts; each content string is re-wrapped into the Converse
            ``[{"text": ...}]`` block format. Defaults to no history.
        aws_access_key_id / aws_secret_access_key / aws_session_token:
            Fallback credentials, used only when the corresponding environment
            variables are not already set.
        **kwargs: Extra options. ``stream=True`` selects ``converse_stream``;
            ``max_tokens``/``temperature``/``top_p``/``stop_sequences`` are
            mapped into ``inferenceConfig``; ``aws_region`` is a fallback for
            the ``AWS_REGION`` env var; OpenAI-style parameters the Converse
            API does not accept are silently dropped.

    Returns:
        The completion text, or an async iterator of text chunks when
        ``stream=True``.

    Raises:
        BedrockError: On any API or streaming failure (also retried by the
            tenacity policy on this function).
    """
    # Fix for mutable default argument: never share a list across calls.
    if history_messages is None:
        history_messages = []

    # Respect existing env; only set if a non-empty value is available.
    access_key = os.environ.get("AWS_ACCESS_KEY_ID") or aws_access_key_id
    secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY") or aws_secret_access_key
    session_token = os.environ.get("AWS_SESSION_TOKEN") or aws_session_token
    _set_env_if_present("AWS_ACCESS_KEY_ID", access_key)
    _set_env_if_present("AWS_SECRET_ACCESS_KEY", secret_key)
    _set_env_if_present("AWS_SESSION_TOKEN", session_token)

    # Region handling: prefer env, else kwarg (optional).
    region = os.environ.get("AWS_REGION") or kwargs.pop("aws_region", None)

    kwargs.pop("hashing_kv", None)  # LightRAG-internal, not a Bedrock parameter

    # Capture the stream flag and remove it from kwargs: it only selects
    # converse_stream vs converse and is not a Bedrock API parameter.
    stream = bool(kwargs.pop("stream", False))

    # Drop OpenAI-style arguments the Bedrock Converse API does not accept.
    # (The original list contained "response_format" twice; deduplicated.)
    for k in (
        "response_format",
        "tools",
        "tool_choice",
        "seed",
        "presence_penalty",
        "frequency_penalty",
        "n",
        "logprobs",
        "top_logprobs",
        "max_completion_tokens",
    ):
        kwargs.pop(k, None)

    # Re-wrap history into the Converse message format. copy.copy avoids
    # mutating the caller's message dicts.
    messages = []
    for history_message in history_messages:
        message = copy.copy(history_message)
        message["content"] = [{"text": message["content"]}]
        messages.append(message)
    # Add user prompt as the final message.
    messages.append({"role": "user", "content": [{"text": prompt}]})

    # Initialize Converse API arguments.
    args = {"modelId": model, "messages": messages}
    if system_prompt:
        args["system"] = [{"text": system_prompt}]

    # Map OpenAI-style inference parameters onto Converse's camelCase names.
    inference_params_map = {
        "max_tokens": "maxTokens",
        "top_p": "topP",
        "stop_sequences": "stopSequences",
    }
    if inference_params := list(
        set(kwargs) & {"max_tokens", "temperature", "top_p", "stop_sequences"}
    ):
        args["inferenceConfig"] = {}
        for param in inference_params:
            args["inferenceConfig"][inference_params_map.get(param, param)] = (
                kwargs.pop(param)
            )

    if stream:
        # Streaming: the client must stay open for the generator's lifetime,
        # so its setup/teardown is managed inside the generator rather than
        # with an `async with` block here.
        session = aioboto3.Session()
        client = None

        async def stream_generator():
            nonlocal client
            client = await session.client(
                "bedrock-runtime", region_name=region
            ).__aenter__()
            event_stream = None
            try:
                response = await client.converse_stream(**args, **kwargs)
                event_stream = response.get("stream")
                async for event in event_stream:
                    # Defensive: skip malformed events.
                    if not event or not isinstance(event, dict):
                        continue
                    if "contentBlockDelta" in event:
                        delta = event["contentBlockDelta"].get("delta", {})
                        text = delta.get("text")
                        if text:
                            yield text
                    elif "messageStop" in event:
                        # Explicit end-of-message marker from Bedrock.
                        break
            except Exception as e:
                logging.error(f"Bedrock streaming error: {e}")
                raise BedrockError(f"Streaming error: {e}")
            finally:
                # Single cleanup path for both success and failure. (The
                # original duplicated this in `except` and `finally`, which
                # could attempt to close the event stream twice.)
                if event_stream is not None and callable(
                    getattr(event_stream, "aclose", None)
                ):
                    try:
                        await event_stream.aclose()
                    except Exception as close_error:
                        logging.warning(
                            f"Failed to close Bedrock event stream: {close_error}"
                        )
                if client:
                    try:
                        await client.__aexit__(None, None, None)
                    except Exception as client_close_error:
                        logging.warning(
                            f"Failed to close Bedrock client: {client_close_error}"
                        )

        # Return the generator that manages its own lifecycle.
        return stream_generator()

    # Non-streaming: a plain async context manager is sufficient.
    session = aioboto3.Session()
    async with session.client(
        "bedrock-runtime", region_name=region
    ) as bedrock_async_client:
        try:
            response = await bedrock_async_client.converse(**args, **kwargs)
            # Validate response structure before indexing into it.
            if (
                not response
                or "output" not in response
                or "message" not in response["output"]
                or "content" not in response["output"]["message"]
                or not response["output"]["message"]["content"]
            ):
                raise BedrockError("Invalid response structure from Bedrock API")
            content = response["output"]["message"]["content"][0]["text"]
            if not content or content.strip() == "":
                raise BedrockError("Received empty content from Bedrock API")
            return content
        except BedrockError:
            raise
        except Exception as e:
            raise BedrockError(f"Bedrock API error: {e}")
# Generic Bedrock completion function
async def bedrock_complete(
    prompt, system_prompt=None, history_messages=None, keyword_extraction=False, **kwargs
) -> Union[str, AsyncIterator[str]]:
    """Complete *prompt* using the model configured in LightRAG's global config.

    Args:
        prompt: User prompt text.
        system_prompt: Optional system instruction, forwarded unchanged.
        history_messages: Prior chat turns, forwarded unchanged. Defaults to
            no history (fix for the original mutable ``[]`` default).
        keyword_extraction: Accepted for interface compatibility with other
            LightRAG LLM bindings; not used by Bedrock.
        **kwargs: Forwarded to :func:`bedrock_complete_if_cache`. Must contain
            ``hashing_kv`` whose ``global_config["llm_model_name"]`` names the
            Bedrock model.

    Returns:
        The completion text, or an async iterator of chunks when streaming.
    """
    if history_messages is None:
        history_messages = []
    # Defensive: drop a duplicate keyword_extraction if a caller passed it
    # positionally-by-name into kwargs of a wrapper.
    kwargs.pop("keyword_extraction", None)
    model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
    result = await bedrock_complete_if_cache(
        model_name,
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        **kwargs,
    )
    return result
# @wrap_embedding_func_with_attrs(embedding_dim=1024)
# @retry(
# stop=stop_after_attempt(3),
# wait=wait_exponential(multiplier=1, min=4, max=10),
# retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), # TODO: fix exceptions
# )
async def bedrock_embed(
    texts: list[str],
    model: str = "amazon.titan-embed-text-v2:0",
    aws_access_key_id=None,
    aws_secret_access_key=None,
    aws_session_token=None,
) -> np.ndarray:
    """Embed *texts* with an Amazon Bedrock embedding model.

    Supports Amazon Titan v1/v2 models (one ``invoke_model`` call per text)
    and Cohere models (one batched call for all texts).

    Args:
        texts: Strings to embed.
        model: Bedrock embedding model id; the provider is taken from the
            prefix before the first ``.``.
        aws_access_key_id / aws_secret_access_key / aws_session_token:
            Fallback credentials, used only when the corresponding environment
            variables are not already set.

    Returns:
        A 2-D ``np.ndarray`` of embeddings, one row per input text.

    Raises:
        ValueError: For an unsupported model or model provider.
    """
    # Respect existing env; only set if a non-empty value is available.
    access_key = os.environ.get("AWS_ACCESS_KEY_ID") or aws_access_key_id
    secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY") or aws_secret_access_key
    session_token = os.environ.get("AWS_SESSION_TOKEN") or aws_session_token
    _set_env_if_present("AWS_ACCESS_KEY_ID", access_key)
    _set_env_if_present("AWS_SECRET_ACCESS_KEY", secret_key)
    _set_env_if_present("AWS_SESSION_TOKEN", session_token)

    # Region handling: prefer env.
    region = os.environ.get("AWS_REGION")

    session = aioboto3.Session()
    async with session.client(
        "bedrock-runtime", region_name=region
    ) as bedrock_async_client:
        if (model_provider := model.split(".")[0]) == "amazon":
            embed_texts = []
            for text in texts:
                if "v2" in model:
                    body = json.dumps(
                        {
                            "inputText": text,
                            # 'dimensions': embedding_dim,
                            "embeddingTypes": ["float"],
                        }
                    )
                elif "v1" in model:
                    body = json.dumps({"inputText": text})
                else:
                    raise ValueError(f"Model {model} is not supported!")

                response = await bedrock_async_client.invoke_model(
                    modelId=model,
                    body=body,
                    accept="application/json",
                    contentType="application/json",
                )
                # BUG FIX: the response body is an async StreamingBody; read it
                # with an awaited read() and decode with json.loads (the
                # original used a non-standard `.json()` accessor).
                response_body = json.loads(await response["body"].read())
                embed_texts.append(response_body["embedding"])
        elif model_provider == "cohere":
            body = json.dumps(
                {"texts": texts, "input_type": "search_document", "truncate": "NONE"}
            )
            # BUG FIX: invoke_model takes ``modelId=``; the original passed
            # ``model=``, which is not a valid parameter of this API.
            response = await bedrock_async_client.invoke_model(
                modelId=model,
                body=body,
                accept="application/json",
                contentType="application/json",
            )
            # BUG FIX: with aioboto3 the body read is a coroutine and must be
            # awaited before decoding (the original called json.loads on the
            # un-awaited read()).
            response_body = json.loads(await response["body"].read())
            embed_texts = response_body["embeddings"]
        else:
            raise ValueError(f"Model provider '{model_provider}' is not supported!")

        return np.array(embed_texts)