# LightRAG/lightrag/llm/bedrock.py
# 2025-05-14 10:57:05 +08:00
#
# 183 lines
# 5.7 KiB
# Python

import copy
import os
import json
import pipmaster as pm # Pipmaster for dynamic library install
# Install aioboto3 on demand so that the `import aioboto3` that follows works
# even when the package was not installed ahead of time.
if not pm.is_installed("aioboto3"):
    pm.install("aioboto3")
import aioboto3
import numpy as np
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
from lightrag.utils import (
locate_json_string_body_from_string,
)
class BedrockError(Exception):
    """Generic error for issues related to Amazon Bedrock.

    Raised by this module to wrap failures of Bedrock API calls so callers
    (and the retry policy) can match on a single exception type.
    """
@retry(
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, max=60),
    retry=retry_if_exception_type(BedrockError),
)
async def bedrock_complete_if_cache(
    model,
    prompt,
    system_prompt=None,
    history_messages=None,
    aws_access_key_id=None,
    aws_secret_access_key=None,
    aws_session_token=None,
    **kwargs,
) -> str:
    """Run a chat completion against Amazon Bedrock through the Converse API.

    Args:
        model: Bedrock model identifier, sent as ``modelId``.
        prompt: Text of the final user message.
        system_prompt: Optional system prompt; omitted from the request when
            falsy.
        history_messages: Optional earlier messages. Each one is a mapping
            whose ``content`` is plain text; it is shallow-copied and the
            text is wrapped into the ``[{"text": ...}]`` form the request
            uses. ``None`` means no history.
        aws_access_key_id: Fallback for the ``AWS_ACCESS_KEY_ID`` environment
            variable, used only when that variable is not already set.
        aws_secret_access_key: Fallback for ``AWS_SECRET_ACCESS_KEY``.
        aws_session_token: Fallback for ``AWS_SESSION_TOKEN``.
        **kwargs: ``max_tokens``, ``temperature``, ``top_p`` and
            ``stop_sequences`` are moved into ``inferenceConfig``;
            ``hashing_kv`` is discarded; anything else is forwarded to
            ``converse`` unchanged.

    Returns:
        The text of the first content block of the model's reply.

    Raises:
        BedrockError: Wraps any exception raised by the ``converse`` call.
            The ``retry`` decorator re-runs the function on this error, for
            up to five attempts with exponential back-off.
    """
    # Values already present in the environment take precedence. A credential
    # argument is exported only when the variable is missing AND a value was
    # actually supplied: os.environ accepts strings only, so assigning None
    # would raise TypeError.
    for env_name, fallback in (
        ("AWS_ACCESS_KEY_ID", aws_access_key_id),
        ("AWS_SECRET_ACCESS_KEY", aws_secret_access_key),
        ("AWS_SESSION_TOKEN", aws_session_token),
    ):
        if fallback is not None and env_name not in os.environ:
            os.environ[env_name] = fallback

    # Not a Converse argument; drop it before kwargs are forwarded.
    kwargs.pop("hashing_kv", None)

    # Fix message history format (copy so the caller's messages stay intact)
    messages = []
    for history_message in history_messages or ():
        message = copy.copy(history_message)
        message["content"] = [{"text": message["content"]}]
        messages.append(message)

    # Add user prompt
    messages.append({"role": "user", "content": [{"text": prompt}]})

    # Initialize Converse API arguments
    args = {"modelId": model, "messages": messages}

    # Define system prompt
    if system_prompt:
        args["system"] = [{"text": system_prompt}]

    # Map and set up inference parameters (caller name -> Converse name)
    inference_params_map = {
        "max_tokens": "maxTokens",
        "temperature": "temperature",
        "top_p": "topP",
        "stop_sequences": "stopSequences",
    }
    inference_config = {
        api_name: kwargs.pop(param)
        for param, api_name in inference_params_map.items()
        if param in kwargs
    }
    if inference_config:
        args["inferenceConfig"] = inference_config

    # Call model via Converse API
    session = aioboto3.Session()
    async with session.client("bedrock-runtime") as bedrock_async_client:
        try:
            response = await bedrock_async_client.converse(**args, **kwargs)
        except Exception as e:
            # Keep the original exception attached as the explicit cause.
            raise BedrockError(e) from e

    return response["output"]["message"]["content"][0]["text"]
# Generic Bedrock completion function
async def bedrock_complete(
    prompt, system_prompt=None, history_messages=None, keyword_extraction=False, **kwargs
) -> str:
    """Complete ``prompt`` with the Bedrock model named in the global config.

    Args:
        prompt: Text of the user message.
        system_prompt: Optional system prompt.
        history_messages: Optional earlier messages; ``None`` means none.
        keyword_extraction: When truthy, the model output is passed through
            ``locate_json_string_body_from_string`` and that result is
            returned instead of the raw text.
        **kwargs: Must contain ``hashing_kv``, whose
            ``global_config["llm_model_name"]`` selects the model. All
            entries are forwarded to ``bedrock_complete_if_cache``.

    Returns:
        The model's reply, or the value produced by
        ``locate_json_string_body_from_string`` when ``keyword_extraction``
        is truthy.

    Raises:
        KeyError: If ``hashing_kv`` is missing from ``kwargs`` or the config
            has no ``llm_model_name`` entry.
    """
    # NOTE: `keyword_extraction` is a named parameter, so it can never show up
    # in **kwargs. Popping it from kwargs and assigning the result here would
    # always reset the flag to None and disable the extraction branch below.
    model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
    result = await bedrock_complete_if_cache(
        model_name,
        prompt,
        system_prompt=system_prompt,
        # Always hand a list downstream, never None.
        history_messages=history_messages if history_messages is not None else [],
        **kwargs,
    )
    if keyword_extraction:  # TODO: use JSON API
        return locate_json_string_body_from_string(result)
    return result
# @wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
# @retry(
# stop=stop_after_attempt(3),
# wait=wait_exponential(multiplier=1, min=4, max=10),
# retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), # TODO: fix exceptions
# )
async def bedrock_embed(
    texts: list[str],
    model: str = "amazon.titan-embed-text-v2:0",
    aws_access_key_id=None,
    aws_secret_access_key=None,
    aws_session_token=None,
) -> np.ndarray:
    """Embed ``texts`` with an Amazon Bedrock embedding model.

    The provider is the part of ``model`` before the first ``"."``:

    * ``amazon``: one ``invoke_model`` request per text. The model name must
      contain ``"v2"`` or ``"v1"``, which selects the request body.
    * ``cohere``: a single request carrying all texts.

    Args:
        texts: Strings to embed.
        model: Bedrock model identifier.
        aws_access_key_id: Fallback for the ``AWS_ACCESS_KEY_ID`` environment
            variable, used only when that variable is not already set.
        aws_secret_access_key: Fallback for ``AWS_SECRET_ACCESS_KEY``.
        aws_session_token: Fallback for ``AWS_SESSION_TOKEN``.

    Returns:
        ``np.array`` built from the embeddings, one entry per input text.

    Raises:
        ValueError: If the provider is neither ``amazon`` nor ``cohere``, or
            an ``amazon`` model name contains neither ``"v2"`` nor ``"v1"``.
    """
    # Values already present in the environment take precedence. A credential
    # argument is exported only when the variable is missing AND a value was
    # actually supplied: os.environ accepts strings only, so assigning None
    # would raise TypeError.
    for env_name, fallback in (
        ("AWS_ACCESS_KEY_ID", aws_access_key_id),
        ("AWS_SECRET_ACCESS_KEY", aws_secret_access_key),
        ("AWS_SESSION_TOKEN", aws_session_token),
    ):
        if fallback is not None and env_name not in os.environ:
            os.environ[env_name] = fallback

    model_provider = model.split(".")[0]

    session = aioboto3.Session()
    async with session.client("bedrock-runtime") as bedrock_async_client:
        if model_provider == "amazon":
            embed_texts = []
            for text in texts:
                if "v2" in model:
                    body = json.dumps(
                        {
                            "inputText": text,
                            # 'dimensions': embedding_dim,
                            "embeddingTypes": ["float"],
                        }
                    )
                elif "v1" in model:
                    body = json.dumps({"inputText": text})
                else:
                    raise ValueError(f"Model {model} is not supported!")

                response = await bedrock_async_client.invoke_model(
                    modelId=model,
                    body=body,
                    accept="application/json",
                    contentType="application/json",
                )
                response_body = await response.get("body").json()
                embed_texts.append(response_body["embedding"])
        elif model_provider == "cohere":
            body = json.dumps(
                {"texts": texts, "input_type": "search_document", "truncate": "NONE"}
            )
            # The request parameter is `modelId` (as in the amazon branch);
            # `model` is not an accepted keyword of invoke_model.
            response = await bedrock_async_client.invoke_model(
                modelId=model,
                body=body,
                accept="application/json",
                contentType="application/json",
            )
            # The async client's body stream must be awaited; an un-awaited
            # read() hands a coroutine object to json.loads.
            raw_body = await response.get("body").read()
            response_body = json.loads(raw_body)
            embed_texts = response_body["embeddings"]
        else:
            raise ValueError(f"Model provider '{model_provider}' is not supported!")

    return np.array(embed_texts)