diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index 9286d369c..08a4f3e5b 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -703,7 +703,12 @@ class BedrockChat(Base): self.bedrock_sk = json.loads(key).get('bedrock_sk', '') self.bedrock_region = json.loads(key).get('bedrock_region', '') self.model_name = model_name - self.client = boto3.client(service_name='bedrock-runtime', region_name=self.bedrock_region, + + if self.bedrock_ak == '' or self.bedrock_sk == '' or self.bedrock_region == '': + # Try to create a client using the default credentials (AWS_PROFILE, AWS_DEFAULT_REGION, etc.) + self.client = boto3.client('bedrock-runtime') + else: + self.client = boto3.client(service_name='bedrock-runtime', region_name=self.bedrock_region, aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk) def chat(self, system, history, gen_conf): diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py index 893bf65ef..17bb84ef5 100644 --- a/rag/llm/embedding_model.py +++ b/rag/llm/embedding_model.py @@ -476,8 +476,13 @@ class BedrockEmbed(Base): self.bedrock_sk = json.loads(key).get('bedrock_sk', '') self.bedrock_region = json.loads(key).get('bedrock_region', '') self.model_name = model_name - self.client = boto3.client(service_name='bedrock-runtime', region_name=self.bedrock_region, - aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk) + + if self.bedrock_ak == '' or self.bedrock_sk == '' or self.bedrock_region == '': + # Try to create a client using the default credentials (AWS_PROFILE, AWS_DEFAULT_REGION, etc.) + self.client = boto3.client('bedrock-runtime') + else: + self.client = boto3.client(service_name='bedrock-runtime', region_name=self.bedrock_region, + aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk) def encode(self, texts: list): texts = [truncate(t, 8196) for t in texts]