From 53fcafd57fe10a69f3dd7b5d46a5786ee32e6ff9 Mon Sep 17 00:00:00 2001
From: Alonso Guevara
Date: Thu, 11 Jul 2024 17:09:04 -0600
Subject: [PATCH] Fix Mixed LLM settings for completions and embeddings (#517)

* Fixed an issue where base OpenAI embeddings can't work with Azure OpenAI LLM

* Removed redundant None from else

* Format

---------

Co-authored-by: Kenny Stryker
---
 .../patch-20240710165603516866.json        |  4 ++++
 graphrag/config/create_graphrag_config.py  | 24 +++++++++++++++----
 2 files changed, 23 insertions(+), 5 deletions(-)
 create mode 100644 .semversioner/next-release/patch-20240710165603516866.json

diff --git a/.semversioner/next-release/patch-20240710165603516866.json b/.semversioner/next-release/patch-20240710165603516866.json
new file mode 100644
index 00000000..b5066cf1
--- /dev/null
+++ b/.semversioner/next-release/patch-20240710165603516866.json
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Fixed an issue where base OpenAI embeddings can't work with Azure OpenAI LLM"
+}
diff --git a/graphrag/config/create_graphrag_config.py b/graphrag/config/create_graphrag_config.py
index c2acd7b1..45edfc84 100644
--- a/graphrag/config/create_graphrag_config.py
+++ b/graphrag/config/create_graphrag_config.py
@@ -137,13 +137,27 @@ def create_graphrag_config(
         config: LLMConfigInput, base: LLMParameters
     ) -> LLMParameters:
         with reader.use(config.get("llm")):
-            api_key = reader.str(Fragment.api_key) or base.api_key
-            api_base = reader.str(Fragment.api_base) or base.api_base
-            api_version = reader.str(Fragment.api_version) or base.api_version
-            api_organization = reader.str("organization") or base.organization
-            api_proxy = reader.str("proxy") or base.proxy
             api_type = reader.str(Fragment.type) or defs.EMBEDDING_TYPE
             api_type = LLMType(api_type) if api_type else defs.LLM_TYPE
+            api_key = reader.str(Fragment.api_key) or base.api_key
+
+            # Handle the following cases:
+            # - same api_base for LLM and embeddings (both Azure)
+            # - different api_bases for LLM and embeddings (both Azure)
+            # - LLM uses Azure OpenAI, while embeddings use base OpenAI (this is the important case)
+            # - LLM uses Azure OpenAI, while embeddings use a third-party OpenAI-like API
+            api_base = (
+                reader.str(Fragment.api_base) or base.api_base
+                if _is_azure(api_type)
+                else reader.str(Fragment.api_base)
+            )
+            api_version = (
+                reader.str(Fragment.api_version) or base.api_version
+                if _is_azure(api_type)
+                else reader.str(Fragment.api_version)
+            )
+            api_organization = reader.str("organization") or base.organization
+            api_proxy = reader.str("proxy") or base.proxy
             cognitive_services_endpoint = (
                 reader.str(Fragment.cognitive_services_endpoint)
                 or base.cognitive_services_endpoint
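
For context on the behavior change, here is a minimal standalone sketch of the fallback rule the second hunk introduces: an unset embeddings api_base (or api_version) now inherits the root LLM value only when the embeddings type is itself Azure. The resolve_api_base helper, the plain-dict configs, and the hard-coded type strings below are illustrative assumptions for this sketch, not graphrag's actual API, which resolves these values through EnvironmentReader and LLMParameters as shown in the diff.

# Minimal sketch of the Azure-gated fallback, assuming graphrag-style type names.
# Everything here is illustrative; the real logic lives in create_graphrag_config.py.

AZURE_TYPES = {"azure_openai_chat", "azure_openai_embedding"}  # assumed type identifiers


def _is_azure(api_type: str) -> bool:
    return api_type in AZURE_TYPES


def resolve_api_base(embeddings_cfg: dict, root_llm_cfg: dict) -> str | None:
    """Resolve api_base for the embeddings LLM.

    Before the patch, an unset embeddings api_base always fell back to the root
    LLM's api_base. After the patch, that fallback applies only when the
    embeddings LLM is itself Azure, so a base-OpenAI (or OpenAI-like) embeddings
    config no longer inherits the Azure endpoint.
    """
    api_type = embeddings_cfg.get("type", "openai_embedding")
    if _is_azure(api_type):
        return embeddings_cfg.get("api_base") or root_llm_cfg.get("api_base")
    return embeddings_cfg.get("api_base")


if __name__ == "__main__":
    root_llm = {
        "type": "azure_openai_chat",
        "api_base": "https://my-instance.openai.azure.com",  # hypothetical endpoint
    }

    # Embeddings on base OpenAI: must NOT inherit the Azure endpoint.
    print(resolve_api_base({"type": "openai_embedding"}, root_llm))  # None

    # Embeddings also on Azure with no explicit endpoint: inherit from the root LLM.
    print(resolve_api_base({"type": "azure_openai_embedding"}, root_llm))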