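"""Azure OpenAI binding for LightRAG.

Async helpers for chat completions and embeddings against Azure OpenAI
deployments, with retry/backoff and configuration resolved from function
arguments and environment variables.
"""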
from collections.abc import Iterable

import os

import pipmaster as pm  # Pipmaster for dynamic library install

# install specific modules
if not pm.is_installed("openai"):
    pm.install("openai")

from openai import (
    AsyncAzureOpenAI,
    APIConnectionError,
    RateLimitError,
    APITimeoutError,
)
from openai.types.chat import ChatCompletionMessageParam

from tenacity import (
    retry,
    stop_after_attempt,
    wait_exponential,
    retry_if_exception_type,
)

from lightrag.utils import (
    wrap_embedding_func_with_attrs,
    locate_json_string_body_from_string,
    safe_unicode_decode,
)

import numpy as np

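
# Retry transient Azure OpenAI failures (rate limits, connection problems,
# timeouts) with exponential backoff before surfacing the error.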
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type(
        (RateLimitError, APIConnectionError, APITimeoutError)
    ),
)
async def azure_openai_complete_if_cache(
    model,
    prompt,
    system_prompt: str | None = None,
    history_messages: Iterable[ChatCompletionMessageParam] | None = None,
    base_url: str | None = None,
    api_key: str | None = None,
    api_version: str | None = None,
    **kwargs,
):
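    """Send a chat completion request to an Azure OpenAI deployment.

    Configuration falls back from explicit arguments to Azure-specific and
    then generic LightRAG environment variables. Returns the response text,
    or an async iterator of text chunks when streaming is requested.
    """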
    # AZURE_OPENAI_DEPLOYMENT overrides the model argument; endpoint, key,
    # and version fall back from explicit arguments to Azure env vars to
    # the generic LightRAG binding vars.
    deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT") or model or os.getenv("LLM_MODEL")
    base_url = (
        base_url or os.getenv("AZURE_OPENAI_ENDPOINT") or os.getenv("LLM_BINDING_HOST")
    )
    api_key = (
        api_key or os.getenv("AZURE_OPENAI_API_KEY") or os.getenv("LLM_BINDING_API_KEY")
    )
    api_version = (
        api_version
        or os.getenv("AZURE_OPENAI_API_VERSION")
        or os.getenv("OPENAI_API_VERSION")
    )

    openai_async_client = AsyncAzureOpenAI(
        azure_endpoint=base_url,
        azure_deployment=deployment,
        api_key=api_key,
        api_version=api_version,
    )
    # hashing_kv is LightRAG's cache handle, not an OpenAI API parameter.
    kwargs.pop("hashing_kv", None)
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    if history_messages:
        messages.extend(history_messages)
    if prompt is not None:
        messages.append({"role": "user", "content": prompt})

    # Structured-output requests go through the beta parse endpoint;
    # everything else uses the standard chat completions API.
    if "response_format" in kwargs:
        response = await openai_async_client.beta.chat.completions.parse(
            model=model, messages=messages, **kwargs
        )
    else:
        response = await openai_async_client.chat.completions.create(
            model=model, messages=messages, **kwargs
        )

    # A streaming response is an async iterator rather than a completed
    # object; wrap it so callers can consume decoded text chunks.
    if hasattr(response, "__aiter__"):

        async def inner():
            async for chunk in response:
                if len(chunk.choices) == 0:
                    continue
                content = chunk.choices[0].delta.content
                if content is None:
                    continue
                if r"\u" in content:
                    content = safe_unicode_decode(content.encode("utf-8"))
                yield content

        return inner()
    else:
        content = response.choices[0].message.content
        # Guard against a None message (e.g. refusals) before scanning for
        # escaped unicode sequences.
        if content and r"\u" in content:
            content = safe_unicode_decode(content.encode("utf-8"))
        return content


async def azure_openai_complete(
    prompt,
    system_prompt=None,
    history_messages=None,
    keyword_extraction=False,
    **kwargs,
) -> str:
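    """LightRAG-facing completion wrapper for the model named by LLM_MODEL
    (default gpt-4o-mini). When keyword_extraction is set, the JSON body is
    extracted from the model output."""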
    result = await azure_openai_complete_if_cache(
        os.getenv("LLM_MODEL", "gpt-4o-mini"),
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        **kwargs,
    )
    if keyword_extraction:  # TODO: use JSON API
        return locate_json_string_body_from_string(result)
    return result


@wrap_embedding_func_with_attrs(embedding_dim=1536)
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type(
        (RateLimitError, APIConnectionError, APITimeoutError)
    ),
)
async def azure_openai_embed(
    texts: list[str],
    model: str | None = None,
    base_url: str | None = None,
    api_key: str | None = None,
    api_version: str | None = None,
) -> np.ndarray:
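    """Embed texts with an Azure OpenAI embedding deployment and return a
    (len(texts), embedding_dim) numpy array."""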
    # Same fallback scheme as the completion helper, using the
    # embedding-specific environment variables.
    deployment = (
        os.getenv("AZURE_EMBEDDING_DEPLOYMENT")
        or model
        or os.getenv("EMBEDDING_MODEL", "text-embedding-3-small")
    )
    base_url = (
        base_url
        or os.getenv("AZURE_EMBEDDING_ENDPOINT")
        or os.getenv("EMBEDDING_BINDING_HOST")
    )
    api_key = (
        api_key
        or os.getenv("AZURE_EMBEDDING_API_KEY")
        or os.getenv("EMBEDDING_BINDING_API_KEY")
    )
    api_version = (
        api_version
        or os.getenv("AZURE_EMBEDDING_API_VERSION")
        or os.getenv("OPENAI_API_VERSION")
    )

    openai_async_client = AsyncAzureOpenAI(
        azure_endpoint=base_url,
        azure_deployment=deployment,
        api_key=api_key,
        api_version=api_version,
    )

    # Use the resolved deployment name: the model argument may be None when
    # configuration comes entirely from environment variables.
    response = await openai_async_client.embeddings.create(
        model=deployment, input=texts, encoding_format="float"
    )
    return np.array([dp.embedding for dp in response.data])
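

# Usage sketch (illustrative, not part of the library). Assumes the
# environment variables referenced above are set: AZURE_OPENAI_ENDPOINT,
# AZURE_OPENAI_API_KEY, AZURE_OPENAI_API_VERSION, AZURE_OPENAI_DEPLOYMENT
# (or LLM_MODEL), and their AZURE_EMBEDDING_* counterparts. The embedding
# call assumes the wrap_embedding_func_with_attrs wrapper forwards calls
# to the underlying async function.
if __name__ == "__main__":
    import asyncio

    async def _demo():
        # One-shot completion against the configured chat deployment.
        answer = await azure_openai_complete("Say hello in one short sentence.")
        print(answer)

        # Embed two short texts; the result should be a (2, 1536) array.
        vectors = await azure_openai_embed(["hello", "world"])
        print(vectors.shape)

    asyncio.run(_demo())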