Resolve confusion between Azure embedding and completion environment variables

This commit is contained in:
Alexander Bruhn 2025-06-04 11:27:59 +02:00
parent 5ec81d652d
commit 5e3970e18b
No known key found for this signature in database
2 changed files with 54 additions and 28 deletions

View File

@ -103,6 +103,8 @@ EMBEDDING_BINDING_HOST=http://localhost:11434
### Optional for Azure embeddings (separate from the AZURE_OPENAI_* completion settings)
# AZURE_EMBEDDING_DEPLOYMENT=text-embedding-3-large
# AZURE_EMBEDDING_API_VERSION=2023-05-15
# AZURE_EMBEDDING_ENDPOINT=your_endpoint
# AZURE_EMBEDDING_API_KEY=your_api_key
### Data storage selection
# LIGHTRAG_KV_STORAGE=PGKVStorage

View File

@ -1,3 +1,4 @@
from collections.abc import Iterable
import os
import pipmaster as pm # Pipmaster for dynamic library install
@ -11,6 +12,8 @@ from openai import (
RateLimitError,
APITimeoutError,
)
from openai.types.chat import ChatCompletionMessageParam
from tenacity import (
retry,
stop_after_attempt,
@ -37,31 +40,38 @@ import numpy as np
async def azure_openai_complete_if_cache(
model,
prompt,
system_prompt=None,
history_messages=[],
base_url=None,
api_key=None,
api_version=None,
system_prompt: str | None = None,
history_messages: Iterable[ChatCompletionMessageParam] | None = None,
base_url: str | None = None,
api_key: str | None = None,
api_version: str | None = None,
**kwargs,
):
if api_key:
os.environ["AZURE_OPENAI_API_KEY"] = api_key
if base_url:
os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
if api_version:
os.environ["AZURE_OPENAI_API_VERSION"] = api_version
model = model or os.getenv("AZURE_OPENAI_DEPLOYMENT") or os.getenv("LLM_MODEL")
base_url = (
base_url or os.getenv("AZURE_OPENAI_ENDPOINT") or os.getenv("LLM_BINDING_HOST")
)
api_key = (
api_key or os.getenv("AZURE_OPENAI_API_KEY") or os.getenv("LLM_BINDING_API_KEY")
)
api_version = (
api_version
or os.getenv("AZURE_OPENAI_API_VERSION")
or os.getenv("OPENAI_API_VERSION")
)
openai_async_client = AsyncAzureOpenAI(
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
azure_endpoint=base_url,
azure_deployment=model,
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
api_key=api_key,
api_version=api_version,
)
kwargs.pop("hashing_kv", None)
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.extend(history_messages)
if history_messages:
messages.extend(history_messages)
if prompt is not None:
messages.append({"role": "user", "content": prompt})
@ -121,23 +131,37 @@ async def azure_openai_complete(
)
async def azure_openai_embed(
texts: list[str],
model: str = os.getenv("EMBEDDING_MODEL", "text-embedding-3-small"),
base_url: str = None,
api_key: str = None,
api_version: str = None,
model: str | None = None,
base_url: str | None = None,
api_key: str | None = None,
api_version: str | None = None,
) -> np.ndarray:
if api_key:
os.environ["AZURE_OPENAI_API_KEY"] = api_key
if base_url:
os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
if api_version:
os.environ["AZURE_OPENAI_API_VERSION"] = api_version
model = (
model
or os.getenv("AZURE_EMBEDDING_DEPLOYMENT")
or os.getenv("EMBEDDING_MODEL", "text-embedding-3-small")
)
base_url = (
base_url
or os.getenv("AZURE_EMBEDDING_ENDPOINT")
or os.getenv("EMBEDDING_BINDING_HOST")
)
api_key = (
api_key
or os.getenv("AZURE_EMBEDDING_API_KEY")
or os.getenv("EMBEDDING_BINDING_API_KEY")
)
api_version = (
api_version
or os.getenv("AZURE_EMBEDDING_API_VERSION")
or os.getenv("OPENAI_API_VERSION")
)
openai_async_client = AsyncAzureOpenAI(
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
azure_endpoint=base_url,
azure_deployment=model,
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
api_key=api_key,
api_version=api_version,
)
response = await openai_async_client.embeddings.create(