From 5e3970e18ba0751968a4ad7172d57e84171d5bf6 Mon Sep 17 00:00:00 2001
From: Alexander Bruhn
Date: Wed, 4 Jun 2025 11:27:59 +0200
Subject: [PATCH] Resolve confusion between azure embedding and completion environment variables

---
 env.example                  |  2 +
 lightrag/llm/azure_openai.py | 80 +++++++++++++++++++++++-------------
 2 files changed, 54 insertions(+), 28 deletions(-)

diff --git a/env.example b/env.example
index 30faaeac..f98648d2 100644
--- a/env.example
+++ b/env.example
@@ -103,6 +103,8 @@ EMBEDDING_BINDING_HOST=http://localhost:11434
 ### Optional for Azure
 # AZURE_EMBEDDING_DEPLOYMENT=text-embedding-3-large
 # AZURE_EMBEDDING_API_VERSION=2023-05-15
+# AZURE_EMBEDDING_ENDPOINT=your_endpoint
+# AZURE_EMBEDDING_API_KEY=your_api_key
 
 ### Data storage selection
 # LIGHTRAG_KV_STORAGE=PGKVStorage
diff --git a/lightrag/llm/azure_openai.py b/lightrag/llm/azure_openai.py
index 1963b4fe..6354ef18 100644
--- a/lightrag/llm/azure_openai.py
+++ b/lightrag/llm/azure_openai.py
@@ -1,3 +1,4 @@
+from collections.abc import Iterable
 import os
 import pipmaster as pm  # Pipmaster for dynamic library install
 
@@ -11,6 +12,8 @@ from openai import (
     RateLimitError,
     APITimeoutError,
 )
+from openai.types.chat import ChatCompletionMessageParam
+
 from tenacity import (
     retry,
     stop_after_attempt,
@@ -37,31 +40,38 @@ import numpy as np
 async def azure_openai_complete_if_cache(
     model,
     prompt,
-    system_prompt=None,
-    history_messages=[],
-    base_url=None,
-    api_key=None,
-    api_version=None,
+    system_prompt: str | None = None,
+    history_messages: Iterable[ChatCompletionMessageParam] | None = None,
+    base_url: str | None = None,
+    api_key: str | None = None,
+    api_version: str | None = None,
     **kwargs,
 ):
-    if api_key:
-        os.environ["AZURE_OPENAI_API_KEY"] = api_key
-    if base_url:
-        os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
-    if api_version:
-        os.environ["AZURE_OPENAI_API_VERSION"] = api_version
+    model = model or os.getenv("AZURE_OPENAI_DEPLOYMENT") or os.getenv("LLM_MODEL")
+    base_url = (
+        base_url or os.getenv("AZURE_OPENAI_ENDPOINT") or os.getenv("LLM_BINDING_HOST")
+    )
+    api_key = (
+        api_key or os.getenv("AZURE_OPENAI_API_KEY") or os.getenv("LLM_BINDING_API_KEY")
+    )
+    api_version = (
+        api_version
+        or os.getenv("AZURE_OPENAI_API_VERSION")
+        or os.getenv("OPENAI_API_VERSION")
+    )
 
     openai_async_client = AsyncAzureOpenAI(
-        azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
+        azure_endpoint=base_url,
         azure_deployment=model,
-        api_key=os.getenv("AZURE_OPENAI_API_KEY"),
-        api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
+        api_key=api_key,
+        api_version=api_version,
     )
     kwargs.pop("hashing_kv", None)
     messages = []
     if system_prompt:
         messages.append({"role": "system", "content": system_prompt})
-    messages.extend(history_messages)
+    if history_messages:
+        messages.extend(history_messages)
     if prompt is not None:
         messages.append({"role": "user", "content": prompt})
 
@@ -121,23 +131,37 @@ async def azure_openai_complete(
 )
 async def azure_openai_embed(
     texts: list[str],
-    model: str = os.getenv("EMBEDDING_MODEL", "text-embedding-3-small"),
-    base_url: str = None,
-    api_key: str = None,
-    api_version: str = None,
+    model: str | None = None,
+    base_url: str | None = None,
+    api_key: str | None = None,
+    api_version: str | None = None,
 ) -> np.ndarray:
-    if api_key:
-        os.environ["AZURE_OPENAI_API_KEY"] = api_key
-    if base_url:
-        os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
-    if api_version:
-        os.environ["AZURE_OPENAI_API_VERSION"] = api_version
+    model = (
+        model
+        or os.getenv("AZURE_EMBEDDING_DEPLOYMENT")
+        or os.getenv("EMBEDDING_MODEL", "text-embedding-3-small")
+    )
+    base_url = (
+        base_url
+        or os.getenv("AZURE_EMBEDDING_ENDPOINT")
+        or os.getenv("EMBEDDING_BINDING_HOST")
+    )
+    api_key = (
+        api_key
+        or os.getenv("AZURE_EMBEDDING_API_KEY")
+        or os.getenv("EMBEDDING_BINDING_API_KEY")
+    )
+    api_version = (
+        api_version
+        or os.getenv("AZURE_EMBEDDING_API_VERSION")
+        or os.getenv("OPENAI_API_VERSION")
+    )
 
     openai_async_client = AsyncAzureOpenAI(
-        azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
+        azure_endpoint=base_url,
         azure_deployment=model,
-        api_key=os.getenv("AZURE_OPENAI_API_KEY"),
-        api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
+        api_key=api_key,
+        api_version=api_version,
     )
 
     response = await openai_async_client.embeddings.create(