| """
 | |
| Bedrock LLM Interface Module
 | |
| ==========================
 | |
| 
 | |
| This module provides interfaces for interacting with Bedrock's language models,
 | |
| including text generation and embedding capabilities.
 | |
| 
 | |
| Author: Lightrag team
 | |
| Created: 2024-01-24
 | |
| License: MIT License
 | |
| 
 | |
| Copyright (c) 2024 Lightrag
 | |
| 
 | |
| Permission is hereby granted, free of charge, to any person obtaining a copy
 | |
| of this software and associated documentation files (the "Software"), to deal
 | |
| in the Software without restriction, including without limitation the rights
 | |
| to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 | |
| copies of the Software, and to permit persons to whom the Software is
 | |
| furnished to do so, subject to the following conditions:
 | |
| 
 | |
| Version: 1.0.0
 | |
| 
 | |
| Change Log:
 | |
| - 1.0.0 (2024-01-24): Initial release
 | |
|     * Added async chat completion support
 | |
|     * Added embedding generation
 | |
|     * Added stream response capability
 | |
| 
 | |
| Dependencies:
 | |
|     - aioboto3, tenacity
 | |
|     - numpy
 | |
|     - pipmaster
 | |
|     - Python >= 3.10
 | |
| 
 | |
| Usage:
 | |
|     from llm_interfaces.bebrock import bebrock_model_complete, bebrock_embed
 | |
| """

__version__ = "1.0.0"
__author__ = "lightrag Team"
__status__ = "Production"


import copy
import os
import json

import pipmaster as pm  # Pipmaster for dynamic library install

if not pm.is_installed("aioboto3"):
    pm.install("aioboto3")
if not pm.is_installed("tenacity"):
    pm.install("tenacity")
import aioboto3
import numpy as np
from tenacity import (
    retry,
    stop_after_attempt,
    wait_exponential,
    retry_if_exception_type,
)

from lightrag.utils import (
    locate_json_string_body_from_string,
)


class BedrockError(Exception):
    """Generic error for issues related to Amazon Bedrock."""


@retry(
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, max=60),
    retry=retry_if_exception_type(BedrockError),
)
async def bedrock_complete_if_cache(
    model,
    prompt,
    system_prompt=None,
    history_messages=None,
    aws_access_key_id=None,
    aws_secret_access_key=None,
    aws_session_token=None,
    **kwargs,
) -> str:
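    """Complete a prompt via the Bedrock Converse API, retrying on BedrockError.

    Returns the first text block of the model's reply. Extra keyword
    arguments are forwarded to ``converse`` after the common inference
    parameters (max_tokens, temperature, top_p, stop_sequences) are
    mapped to their Converse ``inferenceConfig`` names.
    """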
    # Prefer credentials already present in the environment; fall back to the
    # explicit arguments. Skip values that are unset, since assigning None to
    # os.environ raises a TypeError.
    credentials = {
        "AWS_ACCESS_KEY_ID": aws_access_key_id,
        "AWS_SECRET_ACCESS_KEY": aws_secret_access_key,
        "AWS_SESSION_TOKEN": aws_session_token,
    }
    for env_var, fallback in credentials.items():
        value = os.environ.get(env_var, fallback)
        if value is not None:
            os.environ[env_var] = value

    kwargs.pop("hashing_kv", None)
    # Convert message history into the Converse content-block format
    messages = []
    for history_message in history_messages or []:
        message = copy.copy(history_message)
        message["content"] = [{"text": message["content"]}]
        messages.append(message)

    # Add user prompt
    messages.append({"role": "user", "content": [{"text": prompt}]})

    # Initialize Converse API arguments
    args = {"modelId": model, "messages": messages}

    # Define system prompt
    if system_prompt:
        args["system"] = [{"text": system_prompt}]

    # Map and set up inference parameters
    inference_params_map = {
        "max_tokens": "maxTokens",
        "top_p": "topP",
        "stop_sequences": "stopSequences",
    }
    if inference_params := list(
        set(kwargs) & {"max_tokens", "temperature", "top_p", "stop_sequences"}
    ):
        args["inferenceConfig"] = {}
        for param in inference_params:
            args["inferenceConfig"][inference_params_map.get(param, param)] = (
                kwargs.pop(param)
            )

    # Call model via Converse API
    session = aioboto3.Session()
    async with session.client("bedrock-runtime") as bedrock_async_client:
        try:
            response = await bedrock_async_client.converse(**args, **kwargs)
        except Exception as e:
            raise BedrockError(e) from e

    return response["output"]["message"]["content"][0]["text"]


async def bedrock_complete(
    prompt, system_prompt=None, history_messages=None, keyword_extraction=False, **kwargs
) -> str:
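    """High-level completion helper bound to Claude 3 Haiku on Bedrock.

    If ``keyword_extraction`` is true, the JSON body embedded in the model
    output is located and returned instead of the raw completion text.
    """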
    result = await bedrock_complete_if_cache(
        "anthropic.claude-3-haiku-20240307-v1:0",
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        **kwargs,
    )
    if keyword_extraction:  # TODO: use JSON API
        return locate_json_string_body_from_string(result)
    return result


# @wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
# @retry(
#     stop=stop_after_attempt(3),
#     wait=wait_exponential(multiplier=1, min=4, max=10),
#     retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),  # TODO: fix exceptions
# )
async def bedrock_embed(
    texts: list[str],
    model: str = "amazon.titan-embed-text-v2:0",
    aws_access_key_id=None,
    aws_secret_access_key=None,
    aws_session_token=None,
) -> np.ndarray:
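    """Embed a list of texts with a Bedrock embedding model.

    Amazon Titan models (v1/v2) are invoked once per text; Cohere models
    accept the whole batch in a single request. Returns a 2-D array with
    one embedding row per input text.
    """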
    # Same credential handling as bedrock_complete_if_cache: only set
    # variables that resolve to a real value.
    credentials = {
        "AWS_ACCESS_KEY_ID": aws_access_key_id,
        "AWS_SECRET_ACCESS_KEY": aws_secret_access_key,
        "AWS_SESSION_TOKEN": aws_session_token,
    }
    for env_var, fallback in credentials.items():
        value = os.environ.get(env_var, fallback)
        if value is not None:
            os.environ[env_var] = value

    session = aioboto3.Session()
    async with session.client("bedrock-runtime") as bedrock_async_client:
        if (model_provider := model.split(".")[0]) == "amazon":
            embed_texts = []
            for text in texts:
                if "v2" in model:
                    body = json.dumps(
                        {
                            "inputText": text,
                            # 'dimensions': embedding_dim,
                            "embeddingTypes": ["float"],
                        }
                    )
                elif "v1" in model:
                    body = json.dumps({"inputText": text})
                else:
                    raise ValueError(f"Model {model} is not supported!")

                response = await bedrock_async_client.invoke_model(
                    modelId=model,
                    body=body,
                    accept="application/json",
                    contentType="application/json",
                )

                response_body = json.loads(await response["body"].read())

                embed_texts.append(response_body["embedding"])
        elif model_provider == "cohere":
            body = json.dumps(
                {"texts": texts, "input_type": "search_document", "truncate": "NONE"}
            )

            response = await bedrock_async_client.invoke_model(
                modelId=model,
                body=body,
                accept="application/json",
                contentType="application/json",
            )

            response_body = json.loads(await response["body"].read())

            embed_texts = response_body["embeddings"]
        else:
            raise ValueError(f"Model provider '{model_provider}' is not supported!")

        return np.array(embed_texts)
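

# ---------------------------------------------------------------------------
# Example usage (illustrative sketch, not part of the public API): assumes AWS
# credentials are available in the environment and that the account has access
# to the referenced Bedrock models in the active region.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import asyncio

    async def _demo():
        # Chat completion with a short conversation history
        answer = await bedrock_complete(
            "What is Amazon Bedrock?",
            system_prompt="Answer in one sentence.",
            history_messages=[
                {"role": "user", "content": "Hi"},
                {"role": "assistant", "content": "Hello! How can I help?"},
            ],
            max_tokens=256,
        )
        print(answer)

        # Embedding generation; result shape is (num_texts, embedding_dim)
        vectors = await bedrock_embed(["first document", "second document"])
        print(vectors.shape)

    asyncio.run(_demo())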
