Mirror of https://github.com/HKUDS/LightRAG.git (synced 2025-06-26 22:00:19 +00:00)

Resolve confusion between azure embedding and completion environment variables

This commit is contained in:
parent 5ec81d652d
commit 5e3970e18b
@@ -103,6 +103,8 @@ EMBEDDING_BINDING_HOST=http://localhost:11434
 ### Optional for Azure
 # AZURE_EMBEDDING_DEPLOYMENT=text-embedding-3-large
 # AZURE_EMBEDDING_API_VERSION=2023-05-15
+# AZURE_EMBEDDING_ENDPOINT=your_endpoint
+# AZURE_EMBEDDING_API_KEY=your_api_key

 ### Data storage selection
 # LIGHTRAG_KV_STORAGE=PGKVStorage
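
For illustration only (the values and the in-process approach below are not part of this commit), enabling the Azure embedding binding amounts to uncommenting the AZURE_EMBEDDING_* lines above and filling in real values; the same effect can be sketched from Python by exporting the variables before the binding reads them:

import os

# Placeholder values for illustration only; substitute your own deployment,
# endpoint, and key. The endpoint shape is an assumption.
os.environ["AZURE_EMBEDDING_DEPLOYMENT"] = "text-embedding-3-large"
os.environ["AZURE_EMBEDDING_API_VERSION"] = "2023-05-15"
os.environ["AZURE_EMBEDDING_ENDPOINT"] = "https://my-resource.openai.azure.com"
os.environ["AZURE_EMBEDDING_API_KEY"] = "your_api_key"
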
@@ -1,3 +1,4 @@
+from collections.abc import Iterable
 import os
 import pipmaster as pm  # Pipmaster for dynamic library install

@@ -11,6 +12,8 @@ from openai import (
     RateLimitError,
     APITimeoutError,
 )
+from openai.types.chat import ChatCompletionMessageParam
+
 from tenacity import (
     retry,
     stop_after_attempt,
@@ -37,31 +40,38 @@ import numpy as np
 async def azure_openai_complete_if_cache(
     model,
     prompt,
-    system_prompt=None,
-    history_messages=[],
-    base_url=None,
-    api_key=None,
-    api_version=None,
+    system_prompt: str | None = None,
+    history_messages: Iterable[ChatCompletionMessageParam] | None = None,
+    base_url: str | None = None,
+    api_key: str | None = None,
+    api_version: str | None = None,
     **kwargs,
 ):
-    if api_key:
-        os.environ["AZURE_OPENAI_API_KEY"] = api_key
-    if base_url:
-        os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
-    if api_version:
-        os.environ["AZURE_OPENAI_API_VERSION"] = api_version
+    model = model or os.getenv("AZURE_OPENAI_DEPLOYMENT") or os.getenv("LLM_MODEL")
+    base_url = (
+        base_url or os.getenv("AZURE_OPENAI_ENDPOINT") or os.getenv("LLM_BINDING_HOST")
+    )
+    api_key = (
+        api_key or os.getenv("AZURE_OPENAI_API_KEY") or os.getenv("LLM_BINDING_API_KEY")
+    )
+    api_version = (
+        api_version
+        or os.getenv("AZURE_OPENAI_API_VERSION")
+        or os.getenv("OPENAI_API_VERSION")
+    )

     openai_async_client = AsyncAzureOpenAI(
-        azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
+        azure_endpoint=base_url,
         azure_deployment=model,
-        api_key=os.getenv("AZURE_OPENAI_API_KEY"),
-        api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
+        api_key=api_key,
+        api_version=api_version,
     )
     kwargs.pop("hashing_kv", None)
     messages = []
     if system_prompt:
         messages.append({"role": "system", "content": system_prompt})
-    messages.extend(history_messages)
+    if history_messages:
+        messages.extend(history_messages)
     if prompt is not None:
         messages.append({"role": "user", "content": prompt})
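
The hunk above replaces the os.environ mutation with local fallbacks, so a caller can rely purely on environment configuration. A minimal usage sketch, assuming the function lives in lightrag.llm.azure_openai as in the upstream repository and that the AZURE_OPENAI_* (or generic LLM_*) variables are already set:

import asyncio

from lightrag.llm.azure_openai import azure_openai_complete_if_cache


async def main() -> None:
    # model=None falls back to AZURE_OPENAI_DEPLOYMENT or LLM_MODEL; the endpoint,
    # key, and API version are resolved the same way, so nothing else is passed.
    answer = await azure_openai_complete_if_cache(
        model=None,
        prompt="Summarize what LightRAG does in one sentence.",
        system_prompt="You are a concise assistant.",
    )
    print(answer)


asyncio.run(main())
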
@@ -121,23 +131,37 @@ async def azure_openai_complete(
 )
 async def azure_openai_embed(
     texts: list[str],
-    model: str = os.getenv("EMBEDDING_MODEL", "text-embedding-3-small"),
-    base_url: str = None,
-    api_key: str = None,
-    api_version: str = None,
+    model: str | None = None,
+    base_url: str | None = None,
+    api_key: str | None = None,
+    api_version: str | None = None,
 ) -> np.ndarray:
-    if api_key:
-        os.environ["AZURE_OPENAI_API_KEY"] = api_key
-    if base_url:
-        os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
-    if api_version:
-        os.environ["AZURE_OPENAI_API_VERSION"] = api_version
+    model = (
+        model
+        or os.getenv("AZURE_EMBEDDING_DEPLOYMENT")
+        or os.getenv("EMBEDDING_MODEL", "text-embedding-3-small")
+    )
+    base_url = (
+        base_url
+        or os.getenv("AZURE_EMBEDDING_ENDPOINT")
+        or os.getenv("EMBEDDING_BINDING_HOST")
+    )
+    api_key = (
+        api_key
+        or os.getenv("AZURE_EMBEDDING_API_KEY")
+        or os.getenv("EMBEDDING_BINDING_API_KEY")
+    )
+    api_version = (
+        api_version
+        or os.getenv("AZURE_EMBEDDING_API_VERSION")
+        or os.getenv("OPENAI_API_VERSION")
+    )

     openai_async_client = AsyncAzureOpenAI(
-        azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
+        azure_endpoint=base_url,
         azure_deployment=model,
-        api_key=os.getenv("AZURE_OPENAI_API_KEY"),
-        api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
+        api_key=api_key,
+        api_version=api_version,
     )

     response = await openai_async_client.embeddings.create(
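
With the embedding path now reading AZURE_EMBEDDING_* first (falling back to the generic EMBEDDING_* settings) and no longer writing into the AZURE_OPENAI_* variables used by the completion path, the two bindings can point at different deployments. A usage sketch under the same assumption about the module path:

import asyncio

import numpy as np

from lightrag.llm.azure_openai import azure_openai_embed


async def main() -> None:
    # Credentials and the deployment name are resolved from the AZURE_EMBEDDING_* /
    # EMBEDDING_* variables, so only the texts are passed in.
    vectors: np.ndarray = await azure_openai_embed(
        ["LightRAG combines a knowledge graph with vector retrieval."]
    )
    print(vectors.shape)


asyncio.run(main())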