Resolve confusion between azure embedding and completion environment variables

This commit is contained in:
Alexander Bruhn 2025-06-04 11:27:59 +02:00
parent 5ec81d652d
commit 5e3970e18b
No known key found for this signature in database
2 changed files with 54 additions and 28 deletions

View File

@ -103,6 +103,8 @@ EMBEDDING_BINDING_HOST=http://localhost:11434
### Optional for Azure ### Optional for Azure
# AZURE_EMBEDDING_DEPLOYMENT=text-embedding-3-large # AZURE_EMBEDDING_DEPLOYMENT=text-embedding-3-large
# AZURE_EMBEDDING_API_VERSION=2023-05-15 # AZURE_EMBEDDING_API_VERSION=2023-05-15
# AZURE_EMBEDDING_ENDPOINT=your_endpoint
# AZURE_EMBEDDING_API_KEY=your_api_key
### Data storage selection ### Data storage selection
# LIGHTRAG_KV_STORAGE=PGKVStorage # LIGHTRAG_KV_STORAGE=PGKVStorage

View File

@ -1,3 +1,4 @@
from collections.abc import Iterable
import os import os
import pipmaster as pm # Pipmaster for dynamic library install import pipmaster as pm # Pipmaster for dynamic library install
@ -11,6 +12,8 @@ from openai import (
RateLimitError, RateLimitError,
APITimeoutError, APITimeoutError,
) )
from openai.types.chat import ChatCompletionMessageParam
from tenacity import ( from tenacity import (
retry, retry,
stop_after_attempt, stop_after_attempt,
@ -37,31 +40,38 @@ import numpy as np
async def azure_openai_complete_if_cache( async def azure_openai_complete_if_cache(
model, model,
prompt, prompt,
system_prompt=None, system_prompt: str | None = None,
history_messages=[], history_messages: Iterable[ChatCompletionMessageParam] | None = None,
base_url=None, base_url: str | None = None,
api_key=None, api_key: str | None = None,
api_version=None, api_version: str | None = None,
**kwargs, **kwargs,
): ):
if api_key: model = model or os.getenv("AZURE_OPENAI_DEPLOYMENT") or os.getenv("LLM_MODEL")
os.environ["AZURE_OPENAI_API_KEY"] = api_key base_url = (
if base_url: base_url or os.getenv("AZURE_OPENAI_ENDPOINT") or os.getenv("LLM_BINDING_HOST")
os.environ["AZURE_OPENAI_ENDPOINT"] = base_url )
if api_version: api_key = (
os.environ["AZURE_OPENAI_API_VERSION"] = api_version api_key or os.getenv("AZURE_OPENAI_API_KEY") or os.getenv("LLM_BINDING_API_KEY")
)
api_version = (
api_version
or os.getenv("AZURE_OPENAI_API_VERSION")
or os.getenv("OPENAI_API_VERSION")
)
openai_async_client = AsyncAzureOpenAI( openai_async_client = AsyncAzureOpenAI(
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), azure_endpoint=base_url,
azure_deployment=model, azure_deployment=model,
api_key=os.getenv("AZURE_OPENAI_API_KEY"), api_key=api_key,
api_version=os.getenv("AZURE_OPENAI_API_VERSION"), api_version=api_version,
) )
kwargs.pop("hashing_kv", None) kwargs.pop("hashing_kv", None)
messages = [] messages = []
if system_prompt: if system_prompt:
messages.append({"role": "system", "content": system_prompt}) messages.append({"role": "system", "content": system_prompt})
messages.extend(history_messages) if history_messages:
messages.extend(history_messages)
if prompt is not None: if prompt is not None:
messages.append({"role": "user", "content": prompt}) messages.append({"role": "user", "content": prompt})
@ -121,23 +131,37 @@ async def azure_openai_complete(
) )
async def azure_openai_embed( async def azure_openai_embed(
texts: list[str], texts: list[str],
model: str = os.getenv("EMBEDDING_MODEL", "text-embedding-3-small"), model: str | None = None,
base_url: str = None, base_url: str | None = None,
api_key: str = None, api_key: str | None = None,
api_version: str = None, api_version: str | None = None,
) -> np.ndarray: ) -> np.ndarray:
if api_key: model = (
os.environ["AZURE_OPENAI_API_KEY"] = api_key model
if base_url: or os.getenv("AZURE_EMBEDDING_DEPLOYMENT")
os.environ["AZURE_OPENAI_ENDPOINT"] = base_url or os.getenv("EMBEDDING_MODEL", "text-embedding-3-small")
if api_version: )
os.environ["AZURE_OPENAI_API_VERSION"] = api_version base_url = (
base_url
or os.getenv("AZURE_EMBEDDING_ENDPOINT")
or os.getenv("EMBEDDING_BINDING_HOST")
)
api_key = (
api_key
or os.getenv("AZURE_EMBEDDING_API_KEY")
or os.getenv("EMBEDDING_BINDING_API_KEY")
)
api_version = (
api_version
or os.getenv("AZURE_EMBEDDING_API_VERSION")
or os.getenv("OPENAI_API_VERSION")
)
openai_async_client = AsyncAzureOpenAI( openai_async_client = AsyncAzureOpenAI(
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), azure_endpoint=base_url,
azure_deployment=model, azure_deployment=model,
api_key=os.getenv("AZURE_OPENAI_API_KEY"), api_key=api_key,
api_version=os.getenv("AZURE_OPENAI_API_VERSION"), api_version=api_version,
) )
response = await openai_async_client.embeddings.create( response = await openai_async_client.embeddings.create(