From 4f2816c01c49629ce1b9ba891e34b7a09853abdd Mon Sep 17 00:00:00 2001 From: Omar Leonardo Sanchez Granados Date: Sun, 23 Feb 2025 22:01:14 -0500 Subject: [PATCH] Add support to boto3 default connection (#5246) ### What problem does this PR solve? This pull request changes the initialization logic of the `BedrockChat` and `BedrockEmbed` classes to enhance the handling of AWS credentials: when the access key, secret key, or region is not configured, the Bedrock client is created from boto3's default credentials instead. Use cases: - Use environment variables for credentials instead of managing them in the DB - Easier connection when deploying on an AWS machine ### Type of change - [X] New Feature (non-breaking change which adds functionality) --- rag/llm/chat_model.py | 7 ++++++- rag/llm/embedding_model.py | 9 +++++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index 9286d369..08a4f3e5 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -703,7 +703,12 @@ class BedrockChat(Base): self.bedrock_sk = json.loads(key).get('bedrock_sk', '') self.bedrock_region = json.loads(key).get('bedrock_region', '') self.model_name = model_name - self.client = boto3.client(service_name='bedrock-runtime', region_name=self.bedrock_region, + + if self.bedrock_ak == '' or self.bedrock_sk == '' or self.bedrock_region == '': + # Try to create a client using the default credentials (AWS_PROFILE, AWS_DEFAULT_REGION, etc.) 
+ self.client = boto3.client('bedrock-runtime') + else: + self.client = boto3.client(service_name='bedrock-runtime', region_name=self.bedrock_region, aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk) def chat(self, system, history, gen_conf): diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py index 893bf65e..17bb84ef 100644 --- a/rag/llm/embedding_model.py +++ b/rag/llm/embedding_model.py @@ -476,8 +476,13 @@ class BedrockEmbed(Base): self.bedrock_sk = json.loads(key).get('bedrock_sk', '') self.bedrock_region = json.loads(key).get('bedrock_region', '') self.model_name = model_name - self.client = boto3.client(service_name='bedrock-runtime', region_name=self.bedrock_region, - aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk) + + if self.bedrock_ak == '' or self.bedrock_sk == '' or self.bedrock_region == '': + # Try to create a client using the default credentials (AWS_PROFILE, AWS_DEFAULT_REGION, etc.) + self.client = boto3.client('bedrock-runtime') + else: + self.client = boto3.client(service_name='bedrock-runtime', region_name=self.bedrock_region, + aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk) def encode(self, texts: list): texts = [truncate(t, 8196) for t in texts]