Fix Mixed LLM settings for completions and embeddings (#517)

* Fixed an issue where base OpenAI embeddings can't work with Azure OpenAI LLM

* Removed redundant None from else

* Format

---------

Co-authored-by: Kenny Stryker <nggkenny@gmail.com>
This commit is contained in:
Alonso Guevara 2024-07-11 17:09:04 -06:00 committed by GitHub
parent af74cae191
commit 53fcafd57f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 23 additions and 5 deletions

View File

@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Fixed an issue where base OpenAI embeddings can't work with Azure OpenAI LLM"
}

View File

@ -137,13 +137,27 @@ def create_graphrag_config(
config: LLMConfigInput, base: LLMParameters
) -> LLMParameters:
with reader.use(config.get("llm")):
api_key = reader.str(Fragment.api_key) or base.api_key
api_base = reader.str(Fragment.api_base) or base.api_base
api_version = reader.str(Fragment.api_version) or base.api_version
api_organization = reader.str("organization") or base.organization
api_proxy = reader.str("proxy") or base.proxy
api_type = reader.str(Fragment.type) or defs.EMBEDDING_TYPE
api_type = LLMType(api_type) if api_type else defs.LLM_TYPE
api_key = reader.str(Fragment.api_key) or base.api_key
# In a unique events where:
# - same api_bases for LLM and embeddings (both Azure)
# - different api_bases for LLM and embeddings (both Azure)
# - LLM uses Azure OpenAI, while embeddings uses base OpenAI (this one is important)
# - LLM uses Azure OpenAI, while embeddings uses third-party OpenAI-like API
api_base = (
reader.str(Fragment.api_base) or base.api_base
if _is_azure(api_type)
else reader.str(Fragment.api_base)
)
api_version = (
reader.str(Fragment.api_version) or base.api_version
if _is_azure(api_type)
else reader.str(Fragment.api_version)
)
api_organization = reader.str("organization") or base.organization
api_proxy = reader.str("proxy") or base.proxy
cognitive_services_endpoint = (
reader.str(Fragment.cognitive_services_endpoint)
or base.cognitive_services_endpoint