Mirror of https://github.com/HKUDS/LightRAG.git, synced 2025-07-24 17:30:47 +00:00

This avoids resource-cleanup issues (too many open files) under massively parallel calls to the OpenAI API, since RAII-style cleanup is highly unreliable in Python in such contexts.

from ..utils import verbose_debug, VERBOSE_DEBUG

import sys
import os
import logging

if sys.version_info < (3, 9):
    from typing import AsyncIterator
else:
    from collections.abc import AsyncIterator

import pipmaster as pm  # Pipmaster for dynamic library install

# Install required modules on first use
if not pm.is_installed("openai"):
    pm.install("openai")

from openai import (
    AsyncOpenAI,
    APIConnectionError,
    RateLimitError,
    APITimeoutError,
)
from tenacity import (
    retry,
    stop_after_attempt,
    wait_exponential,
    retry_if_exception_type,
)
from lightrag.utils import (
    wrap_embedding_func_with_attrs,
    locate_json_string_body_from_string,
    safe_unicode_decode,
    logger,
)
from lightrag.types import GPTKeywordExtractionFormat
from lightrag.api import __api_version__

import numpy as np
from typing import Any, Union

from dotenv import load_dotenv

# Use the .env file in the current folder, which allows a different .env file
# per LightRAG instance. OS environment variables take precedence over the
# .env file.
load_dotenv(dotenv_path=".env", override=False)
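
# An illustrative .env for this module (a sketch; the key below is a
# placeholder, and OPENAI_API_BASE is optional, defaulting to the public
# OpenAI endpoint):
#
#   OPENAI_API_KEY=sk-...
#   OPENAI_API_BASE=https://api.openai.com/v1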


class InvalidResponseError(Exception):
    """Custom exception class for triggering the retry mechanism."""

    pass


def create_openai_async_client(
    api_key: str | None = None,
    base_url: str | None = None,
    client_configs: dict[str, Any] | None = None,
) -> AsyncOpenAI:
    """Create an AsyncOpenAI client with the given configuration.

    Args:
        api_key: OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
        base_url: Base URL for the OpenAI API. If None, uses the OPENAI_API_BASE
            environment variable or the default OpenAI API URL.
        client_configs: Additional configuration options for the AsyncOpenAI client.
            These will override any default configurations but will be overridden by
            explicit parameters (api_key, base_url).

    Returns:
        An AsyncOpenAI client instance.
    """
    if not api_key:
        api_key = os.environ["OPENAI_API_KEY"]

    default_headers = {
        "User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
        "Content-Type": "application/json",
    }

    if client_configs is None:
        client_configs = {}

    # Merge configs with precedence: explicit params > client_configs > defaults
    merged_configs = {
        **client_configs,
        "default_headers": default_headers,
        "api_key": api_key,
    }

    if base_url is not None:
        merged_configs["base_url"] = base_url
    else:
        merged_configs["base_url"] = os.environ.get(
            "OPENAI_API_BASE", "https://api.openai.com/v1"
        )

    return AsyncOpenAI(**merged_configs)
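
# Illustrative usage of create_openai_async_client (a sketch, assuming
# OPENAI_API_KEY is set; timeout and max_retries are standard AsyncOpenAI
# constructor options):
#
#   import asyncio
#
#   async def _client_demo() -> None:
#       client = create_openai_async_client(
#           client_configs={"timeout": 30.0, "max_retries": 0},
#       )
#       async with client:
#           models = await client.models.list()
#           print([m.id for m in models.data])
#
#   asyncio.run(_client_demo())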


# Retry up to 3 times on transient failures (rate limits, connection errors,
# timeouts) and on invalid or empty responses, with exponential backoff.
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=(
        retry_if_exception_type(RateLimitError)
        | retry_if_exception_type(APIConnectionError)
        | retry_if_exception_type(APITimeoutError)
        | retry_if_exception_type(InvalidResponseError)
    ),
)
async def openai_complete_if_cache(
    model: str,
    prompt: str,
    system_prompt: str | None = None,
    history_messages: list[dict[str, Any]] | None = None,
    base_url: str | None = None,
    api_key: str | None = None,
    token_tracker: Any | None = None,
    **kwargs: Any,
) -> str | AsyncIterator[str]:
    """Complete a prompt using OpenAI's API with caching support.

    Args:
        model: The OpenAI model to use.
        prompt: The prompt to complete.
        system_prompt: Optional system prompt to include.
        history_messages: Optional list of previous messages in the conversation.
        base_url: Optional base URL for the OpenAI API.
        api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
        token_tracker: Optional tracker whose add_usage() method is called with
            the token counts reported by the API.
        **kwargs: Additional keyword arguments to pass to the OpenAI API.
            Special kwargs:
            - openai_client_configs: Dict of configuration options for the AsyncOpenAI client.
              These will be passed to the client constructor but will be overridden by
              explicit parameters (api_key, base_url).
            - hashing_kv: Will be removed from kwargs before passing to OpenAI.
            - keyword_extraction: Will be removed from kwargs before passing to OpenAI.

    Returns:
        The completed text, or an async iterator of text chunks if streaming.

    Raises:
        InvalidResponseError: If the response from OpenAI is invalid or empty.
        APIConnectionError: If there is a connection error with the OpenAI API.
        RateLimitError: If the OpenAI API rate limit is exceeded.
        APITimeoutError: If the OpenAI API request times out.
    """
    if history_messages is None:
        history_messages = []

    # Set the openai logger level to INFO when VERBOSE_DEBUG is off
    if not VERBOSE_DEBUG and logger.level == logging.DEBUG:
        logging.getLogger("openai").setLevel(logging.INFO)

    # Extract client configuration options
    client_configs = kwargs.pop("openai_client_configs", {})

    # Create the OpenAI client
    openai_async_client = create_openai_async_client(
        api_key=api_key, base_url=base_url, client_configs=client_configs
    )

    # Remove special kwargs that shouldn't be passed to OpenAI
    kwargs.pop("hashing_kv", None)
    kwargs.pop("keyword_extraction", None)

    # Prepare messages
    messages: list[dict[str, Any]] = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})

    logger.debug("===== Entering func of LLM =====")
    logger.debug(f"Model: {model} Base URL: {base_url}")
    logger.debug(f"Additional kwargs: {kwargs}")
    logger.debug(f"Num of history messages: {len(history_messages)}")
    verbose_debug(f"System prompt: {system_prompt}")
    verbose_debug(f"Query: {prompt}")
    logger.debug("===== Sending Query to LLM =====")

    try:
        # The async context manager guarantees the client's HTTP resources are
        # released even under massively parallel calls.
        async with openai_async_client:
            if "response_format" in kwargs:
                response = await openai_async_client.beta.chat.completions.parse(
                    model=model, messages=messages, **kwargs
                )
            else:
                response = await openai_async_client.chat.completions.create(
                    model=model, messages=messages, **kwargs
                )
    except APIConnectionError as e:
        logger.error(f"OpenAI API Connection Error: {e}")
        raise
    except RateLimitError as e:
        logger.error(f"OpenAI API Rate Limit Error: {e}")
        raise
    except APITimeoutError as e:
        logger.error(f"OpenAI API Timeout Error: {e}")
        raise
    except Exception as e:
        logger.error(
            f"OpenAI API Call Failed,\nModel: {model},\nParams: {kwargs}, Got: {e}"
        )
        raise

    if hasattr(response, "__aiter__"):
        # Streaming response: wrap it in a generator that closes the stream on
        # both the error and the success path.
        async def inner():
            # Track if we've started iterating
            iteration_started = False
            try:
                iteration_started = True
                async for chunk in response:
                    # Check if choices exists and is not empty
                    if not hasattr(chunk, "choices") or not chunk.choices:
                        logger.warning(f"Received chunk without choices: {chunk}")
                        continue

                    # Check if delta exists and has content
                    if not hasattr(chunk.choices[0], "delta") or not hasattr(
                        chunk.choices[0].delta, "content"
                    ):
                        logger.warning(
                            f"Received chunk without delta content: {chunk.choices[0]}"
                        )
                        continue
                    content = chunk.choices[0].delta.content
                    if content is None:
                        continue
                    if r"\u" in content:
                        content = safe_unicode_decode(content.encode("utf-8"))
                    yield content
            except Exception as e:
                logger.error(f"Error in stream response: {str(e)}")
                # Try to clean up resources if possible
                if (
                    iteration_started
                    and hasattr(response, "aclose")
                    and callable(getattr(response, "aclose", None))
                ):
                    try:
                        await response.aclose()
                        logger.debug("Successfully closed stream response after error")
                    except Exception as close_error:
                        logger.warning(
                            f"Failed to close stream response: {close_error}"
                        )
                raise
            finally:
                # Ensure resources are released even if no exception occurs
                if (
                    iteration_started
                    and hasattr(response, "aclose")
                    and callable(getattr(response, "aclose", None))
                ):
                    try:
                        await response.aclose()
                        logger.debug("Successfully closed stream response")
                    except Exception as close_error:
                        logger.warning(
                            f"Failed to close stream response in finally block: {close_error}"
                        )

        return inner()

    else:
        if (
            not response
            or not response.choices
            or not hasattr(response.choices[0], "message")
            or not hasattr(response.choices[0].message, "content")
        ):
            logger.error("Invalid response from OpenAI API")
            raise InvalidResponseError("Invalid response from OpenAI API")

        content = response.choices[0].message.content

        if not content or content.strip() == "":
            logger.error("Received empty content from OpenAI API")
            raise InvalidResponseError("Received empty content from OpenAI API")

        if r"\u" in content:
            content = safe_unicode_decode(content.encode("utf-8"))

        if token_tracker and hasattr(response, "usage"):
            token_counts = {
                "prompt_tokens": getattr(response.usage, "prompt_tokens", 0),
                "completion_tokens": getattr(response.usage, "completion_tokens", 0),
                "total_tokens": getattr(response.usage, "total_tokens", 0),
            }
            token_tracker.add_usage(token_counts)

        logger.debug(f"Response content len: {len(content)}")
        verbose_debug(f"Response: {response}")

        return content
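
# Illustrative usage of openai_complete_if_cache (a sketch, assuming
# OPENAI_API_KEY is set; the model name is an example). Extra kwargs such as
# stream=True are forwarded to the chat completions endpoint:
#
#   import asyncio
#
#   async def _completion_demo() -> None:
#       # Non-streaming call: returns the full completion as a string.
#       text = await openai_complete_if_cache(
#           "gpt-4o-mini",
#           "Name one prime number.",
#           system_prompt="Answer in a single word.",
#       )
#       print(text)
#
#       # Streaming call: returns an async iterator of text chunks.
#       stream = await openai_complete_if_cache(
#           "gpt-4o-mini",
#           "Count from 1 to 5.",
#           stream=True,
#       )
#       async for chunk in stream:
#           print(chunk, end="", flush=True)
#
#   asyncio.run(_completion_demo())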


async def openai_complete(
    prompt,
    system_prompt=None,
    history_messages=None,
    keyword_extraction=False,
    **kwargs,
) -> Union[str, AsyncIterator[str]]:
    if history_messages is None:
        history_messages = []
    # Accept keyword_extraction either as the declared parameter or via kwargs;
    # popping with a default of None would silently discard the parameter.
    keyword_extraction = kwargs.pop("keyword_extraction", keyword_extraction)
    if keyword_extraction:
        kwargs["response_format"] = "json"
    # The caller must supply hashing_kv, whose global_config names the model.
    model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
    return await openai_complete_if_cache(
        model_name,
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        **kwargs,
    )
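
# Illustrative usage of openai_complete (a sketch; in LightRAG the hashing_kv
# object is supplied by the framework, so the SimpleNamespace stand-in below
# only mimics the single attribute this wrapper reads):
#
#   import asyncio
#   from types import SimpleNamespace
#
#   async def _openai_complete_demo() -> None:
#       hashing_kv = SimpleNamespace(
#           global_config={"llm_model_name": "gpt-4o-mini"}
#       )
#       text = await openai_complete("Say hi.", hashing_kv=hashing_kv)
#       print(text)
#
#   asyncio.run(_openai_complete_demo())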


async def gpt_4o_complete(
    prompt,
    system_prompt=None,
    history_messages=None,
    keyword_extraction=False,
    **kwargs,
) -> str:
    if history_messages is None:
        history_messages = []
    keyword_extraction = kwargs.pop("keyword_extraction", keyword_extraction)
    if keyword_extraction:
        kwargs["response_format"] = GPTKeywordExtractionFormat
    return await openai_complete_if_cache(
        "gpt-4o",
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        **kwargs,
    )


async def gpt_4o_mini_complete(
    prompt,
    system_prompt=None,
    history_messages=None,
    keyword_extraction=False,
    **kwargs,
) -> str:
    if history_messages is None:
        history_messages = []
    keyword_extraction = kwargs.pop("keyword_extraction", keyword_extraction)
    if keyword_extraction:
        kwargs["response_format"] = GPTKeywordExtractionFormat
    return await openai_complete_if_cache(
        "gpt-4o-mini",
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        **kwargs,
    )
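
# Illustrative keyword-extraction usage (a sketch, assuming OPENAI_API_KEY is
# set): with keyword_extraction=True the response_format is set to
# GPTKeywordExtractionFormat, so the request goes through the structured
# output (parse) endpoint and the returned string is the parsed JSON content:
#
#   import asyncio
#
#   async def _keyword_demo() -> None:
#       result = await gpt_4o_mini_complete(
#           "LightRAG combines graph indexing with vector retrieval.",
#           keyword_extraction=True,
#       )
#       print(result)
#
#   asyncio.run(_keyword_demo())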


async def nvidia_openai_complete(
    prompt,
    system_prompt=None,
    history_messages=None,
    keyword_extraction=False,
    **kwargs,
) -> str:
    if history_messages is None:
        history_messages = []
    keyword_extraction = kwargs.pop("keyword_extraction", keyword_extraction)
    result = await openai_complete_if_cache(
        "nvidia/llama-3.1-nemotron-70b-instruct",  # context length 128k
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        base_url="https://integrate.api.nvidia.com/v1",
        **kwargs,
    )
    if keyword_extraction:  # TODO: use JSON API
        return locate_json_string_body_from_string(result)
    return result


@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=60),
    retry=(
        retry_if_exception_type(RateLimitError)
        | retry_if_exception_type(APIConnectionError)
        | retry_if_exception_type(APITimeoutError)
    ),
)
async def openai_embed(
    texts: list[str],
    model: str = "text-embedding-3-small",
    base_url: str | None = None,
    api_key: str | None = None,
    client_configs: dict[str, Any] | None = None,
) -> np.ndarray:
    """Generate embeddings for a list of texts using OpenAI's API.

    Args:
        texts: List of texts to embed.
        model: The OpenAI embedding model to use.
        base_url: Optional base URL for the OpenAI API.
        api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
        client_configs: Additional configuration options for the AsyncOpenAI client.
            These will override any default configurations but will be overridden by
            explicit parameters (api_key, base_url).

    Returns:
        A numpy array of embeddings, one per input text.

    Raises:
        APIConnectionError: If there is a connection error with the OpenAI API.
        RateLimitError: If the OpenAI API rate limit is exceeded.
        APITimeoutError: If the OpenAI API request times out.
    """
    # Create the OpenAI client
    openai_async_client = create_openai_async_client(
        api_key=api_key, base_url=base_url, client_configs=client_configs
    )

    async with openai_async_client:
        response = await openai_async_client.embeddings.create(
            model=model, input=texts, encoding_format="float"
        )
        return np.array([dp.embedding for dp in response.data])
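
# Illustrative usage of openai_embed (a sketch, assuming OPENAI_API_KEY is
# set):
#
#   import asyncio
#
#   async def _embed_demo() -> None:
#       vectors = await openai_embed(["hello world", "LightRAG"])
#       print(vectors.shape)  # (2, 1536) for text-embedding-3-small
#
#   asyncio.run(_embed_demo())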