from ..utils import verbose_debug, VERBOSE_DEBUG

import sys
import os
import logging

if sys.version_info < (3, 9):
    from typing import AsyncIterator
else:
    from collections.abc import AsyncIterator
import pipmaster as pm  # Pipmaster for dynamic library install

# install specific modules
if not pm.is_installed("openai"):
    pm.install("openai")

from openai import (
    AsyncOpenAI,
    APIConnectionError,
    RateLimitError,
    APITimeoutError,
)
from tenacity import (
    retry,
    stop_after_attempt,
    wait_exponential,
    retry_if_exception_type,
)
from lightrag.utils import (
    wrap_embedding_func_with_attrs,
    locate_json_string_body_from_string,
    safe_unicode_decode,
    logger,
)
from lightrag.types import GPTKeywordExtractionFormat
from lightrag.api import __api_version__

import numpy as np
from typing import Any, Union

from dotenv import load_dotenv

# Use the .env file in the current working directory, which allows a
# different .env file per LightRAG instance. OS environment variables take
# precedence over values in the .env file (override=False).
load_dotenv(dotenv_path=".env", override=False)
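
# Sketch of the precedence rule above (the URL is illustrative): because
# override=False, a variable already present in the process environment wins
# over the same key in .env, e.g.
#
#   OPENAI_API_BASE=https://gateway.example/v1 python app.py
#
# keeps that URL even if .env also defines OPENAI_API_BASE.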


class InvalidResponseError(Exception):
    """Custom exception class for triggering retry mechanism"""

    pass
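
# Raising InvalidResponseError from a coroutine decorated with the @retry
# policies in this module makes tenacity re-invoke it, up to the configured
# number of attempts, just as it does for the transient OpenAI errors. A
# minimal sketch of the pattern (hypothetical helper, not part of this module):
#
#   @retry(stop=stop_after_attempt(3),
#          retry=retry_if_exception_type(InvalidResponseError))
#   async def _fetch_or_retry():
#       ...
#       raise InvalidResponseError("empty response")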


def create_openai_async_client(
    api_key: str | None = None,
    base_url: str | None = None,
    client_configs: dict[str, Any] | None = None,
) -> AsyncOpenAI:
    """Create an AsyncOpenAI client with the given configuration.

    Args:
        api_key: OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
        base_url: Base URL for the OpenAI API. If None, uses the OPENAI_API_BASE
            environment variable, falling back to the default OpenAI API URL.
        client_configs: Additional configuration options for the AsyncOpenAI client.
            These will override any default configurations but will be overridden by
            explicit parameters (api_key, base_url).

    Returns:
        An AsyncOpenAI client instance.
    """
    if not api_key:
        api_key = os.environ["OPENAI_API_KEY"]

    default_headers = {
        "User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
        "Content-Type": "application/json",
    }

    if client_configs is None:
        client_configs = {}

    # Create a merged config dict with precedence: explicit params > client_configs > defaults
    merged_configs = {
        **client_configs,
        "default_headers": default_headers,
        "api_key": api_key,
    }

    if base_url is not None:
        merged_configs["base_url"] = base_url
    else:
        merged_configs["base_url"] = os.environ.get(
            "OPENAI_API_BASE", "https://api.openai.com/v1"
        )

    return AsyncOpenAI(**merged_configs)
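

# Usage sketch (the gateway URL and timeout below are assumptions for
# illustration, not LightRAG defaults). client_configs entries are passed
# through to AsyncOpenAI but lose to the explicit api_key/base_url parameters.
async def _example_create_client() -> None:
    client = create_openai_async_client(
        base_url="http://localhost:8000/v1",  # hypothetical OpenAI-compatible gateway
        client_configs={"timeout": 30.0, "max_retries": 0},
    )
    await client.close()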


@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=(
        retry_if_exception_type(RateLimitError)
        | retry_if_exception_type(APIConnectionError)
        | retry_if_exception_type(APITimeoutError)
        | retry_if_exception_type(InvalidResponseError)
    ),
)
async def openai_complete_if_cache(
    model: str,
    prompt: str,
    system_prompt: str | None = None,
    history_messages: list[dict[str, Any]] | None = None,
    base_url: str | None = None,
    api_key: str | None = None,
    token_tracker: Any | None = None,
    **kwargs: Any,
) -> str | AsyncIterator[str]:
    """Complete a prompt using OpenAI's API with caching support.

    Args:
        model: The OpenAI model to use.
        prompt: The prompt to complete.
        system_prompt: Optional system prompt to include.
        history_messages: Optional list of previous messages in the conversation.
        base_url: Optional base URL for the OpenAI API.
        api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
        token_tracker: Optional tracker whose add_usage() method is called with
            the prompt/completion/total token counts reported by the API.
        **kwargs: Additional keyword arguments to pass to the OpenAI API.
            Special kwargs:
            - openai_client_configs: Dict of configuration options for the AsyncOpenAI client.
                These will be passed to the client constructor but will be overridden by
                explicit parameters (api_key, base_url).
            - hashing_kv: Will be removed from kwargs before passing to OpenAI.
            - keyword_extraction: Will be removed from kwargs before passing to OpenAI.

    Returns:
        The completed text, or an async iterator of text chunks if streaming.

    Raises:
        InvalidResponseError: If the response from OpenAI is invalid or empty.
        APIConnectionError: If there is a connection error with the OpenAI API.
        RateLimitError: If the OpenAI API rate limit is exceeded.
        APITimeoutError: If the OpenAI API request times out.
    """
    if history_messages is None:
        history_messages = []

    # Set openai logger level to INFO when VERBOSE_DEBUG is off
    if not VERBOSE_DEBUG and logger.level == logging.DEBUG:
        logging.getLogger("openai").setLevel(logging.INFO)

    # Extract client configuration options
    client_configs = kwargs.pop("openai_client_configs", {})

    # Create the OpenAI client
    openai_async_client = create_openai_async_client(
        api_key=api_key, base_url=base_url, client_configs=client_configs
    )

    # Remove special kwargs that shouldn't be passed to OpenAI
    kwargs.pop("hashing_kv", None)
    kwargs.pop("keyword_extraction", None)

    # Prepare messages
    messages: list[dict[str, Any]] = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})

    logger.debug("===== Entering func of LLM =====")
    logger.debug(f"Model: {model} Base URL: {base_url}")
    logger.debug(f"Additional kwargs: {kwargs}")
    logger.debug(f"Num of history messages: {len(history_messages)}")
    verbose_debug(f"System prompt: {system_prompt}")
    verbose_debug(f"Query: {prompt}")
    logger.debug("===== Sending Query to LLM =====")

    # Allow the caller to supply a fully prepared message list instead
    messages = kwargs.pop("messages", messages)

    try:
        # Don't use async with context manager, use client directly
        if "response_format" in kwargs:
            response = await openai_async_client.beta.chat.completions.parse(
                model=model, messages=messages, **kwargs
            )
        else:
            response = await openai_async_client.chat.completions.create(
                model=model, messages=messages, **kwargs
            )
    except APIConnectionError as e:
        logger.error(f"OpenAI API Connection Error: {e}")
        await openai_async_client.close()  # Ensure client is closed
        raise
    except RateLimitError as e:
        logger.error(f"OpenAI API Rate Limit Error: {e}")
        await openai_async_client.close()  # Ensure client is closed
        raise
    except APITimeoutError as e:
        logger.error(f"OpenAI API Timeout Error: {e}")
        await openai_async_client.close()  # Ensure client is closed
        raise
    except Exception as e:
        logger.error(
            f"OpenAI API Call Failed,\nModel: {model},\nParams: {kwargs}, Got: {e}"
        )
        await openai_async_client.close()  # Ensure client is closed
        raise

    if hasattr(response, "__aiter__"):

        async def inner():
            # Track if we've started iterating
            iteration_started = False
            final_chunk_usage = None

            try:
                iteration_started = True
                async for chunk in response:
                    # Check if this chunk has usage information (final chunk)
                    if hasattr(chunk, "usage") and chunk.usage:
                        final_chunk_usage = chunk.usage
                        logger.debug(
                            f"Received usage info in streaming chunk: {chunk.usage}"
                        )

                    # Check if choices exists and is not empty
                    if not hasattr(chunk, "choices") or not chunk.choices:
                        logger.warning(f"Received chunk without choices: {chunk}")
                        continue

                    # Check if delta exists and has content
                    if not hasattr(chunk.choices[0], "delta") or not hasattr(
                        chunk.choices[0].delta, "content"
                    ):
                        # This might be the final chunk, continue to check for usage
                        continue

                    content = chunk.choices[0].delta.content
                    if content is None:
                        continue
                    if r"\u" in content:
                        content = safe_unicode_decode(content.encode("utf-8"))

                    yield content

                # After streaming is complete, track token usage
                if token_tracker and final_chunk_usage:
                    # Use actual usage from the API
                    token_counts = {
                        "prompt_tokens": getattr(final_chunk_usage, "prompt_tokens", 0),
                        "completion_tokens": getattr(
                            final_chunk_usage, "completion_tokens", 0
                        ),
                        "total_tokens": getattr(final_chunk_usage, "total_tokens", 0),
                    }
                    token_tracker.add_usage(token_counts)
                    logger.debug(f"Streaming token usage (from API): {token_counts}")
                elif token_tracker:
                    logger.debug("No usage information available in streaming response")
            except Exception as e:
                logger.error(f"Error in stream response: {str(e)}")
                # Try to clean up resources if possible
                if (
                    iteration_started
                    and hasattr(response, "aclose")
                    and callable(getattr(response, "aclose", None))
                ):
                    try:
                        await response.aclose()
                        logger.debug("Successfully closed stream response after error")
                    except Exception as close_error:
                        logger.warning(
                            f"Failed to close stream response: {close_error}"
                        )
                # Ensure client is closed in case of exception
                await openai_async_client.close()
                raise
            finally:
                # Ensure resources are released even if no exception occurs
                if (
                    iteration_started
                    and hasattr(response, "aclose")
                    and callable(getattr(response, "aclose", None))
                ):
                    try:
                        await response.aclose()
                        logger.debug("Successfully closed stream response")
                    except Exception as close_error:
                        logger.warning(
                            f"Failed to close stream response in finally block: {close_error}"
                        )

                # Close the client here too; the caller only consumes the
                # iterator, so this prevents resource leaks
                try:
                    await openai_async_client.close()
                    logger.debug(
                        "Successfully closed OpenAI client for streaming response"
                    )
                except Exception as client_close_error:
                    logger.warning(
                        f"Failed to close OpenAI client in streaming finally block: {client_close_error}"
                    )

        return inner()

    else:
        try:
            if (
                not response
                or not response.choices
                or not hasattr(response.choices[0], "message")
                or not hasattr(response.choices[0].message, "content")
            ):
                logger.error("Invalid response from OpenAI API")
                await openai_async_client.close()  # Ensure client is closed
                raise InvalidResponseError("Invalid response from OpenAI API")

            content = response.choices[0].message.content

            if not content or content.strip() == "":
                logger.error("Received empty content from OpenAI API")
                await openai_async_client.close()  # Ensure client is closed
                raise InvalidResponseError("Received empty content from OpenAI API")

            if r"\u" in content:
                content = safe_unicode_decode(content.encode("utf-8"))

            if token_tracker and hasattr(response, "usage"):
                token_counts = {
                    "prompt_tokens": getattr(response.usage, "prompt_tokens", 0),
                    "completion_tokens": getattr(
                        response.usage, "completion_tokens", 0
                    ),
                    "total_tokens": getattr(response.usage, "total_tokens", 0),
                }
                token_tracker.add_usage(token_counts)

            logger.debug(f"Response content len: {len(content)}")
            verbose_debug(f"Response: {response}")

            return content
        finally:
            # Ensure client is closed in all cases for non-streaming responses
            await openai_async_client.close()
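

# Usage sketch (the model name and prompts are illustrative; _TokenCounter is
# a hypothetical stand-in for any object exposing add_usage()). With
# stream=True the call returns an async iterator of chunks; otherwise a str.
class _TokenCounter:
    def __init__(self) -> None:
        self.total = 0

    def add_usage(self, counts: dict[str, int]) -> None:
        self.total += counts.get("total_tokens", 0)


async def _example_completion() -> None:
    tracker = _TokenCounter()
    # Non-streaming: returns the full completion text
    text = await openai_complete_if_cache(
        "gpt-4o-mini",
        "Summarize LightRAG in one sentence.",
        system_prompt="You are a concise assistant.",
        token_tracker=tracker,
    )
    print(text, tracker.total)

    # Streaming: returns an async iterator of text chunks
    chunks = await openai_complete_if_cache(
        "gpt-4o-mini",
        "Summarize LightRAG in one sentence.",
        stream=True,
    )
    async for piece in chunks:
        print(piece, end="")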


async def openai_complete(
    prompt,
    system_prompt=None,
    history_messages=None,
    keyword_extraction=False,
    **kwargs,
) -> Union[str, AsyncIterator[str]]:
    if history_messages is None:
        history_messages = []
    # Honor the explicit parameter; drop any stray copy left in kwargs
    # (popping into the local name here would always overwrite it with None)
    kwargs.pop("keyword_extraction", None)
    if keyword_extraction:
        kwargs["response_format"] = "json"
    model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
    return await openai_complete_if_cache(
        model_name,
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        **kwargs,
    )
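

# Usage sketch (assumption: SimpleNamespace stands in for LightRAG's real
# hashing_kv storage object; only the global_config lookup above is relied on
# by openai_complete, and openai_complete_if_cache strips the key afterwards).
async def _example_openai_complete() -> None:
    from types import SimpleNamespace

    hashing_kv = SimpleNamespace(global_config={"llm_model_name": "gpt-4o-mini"})
    text = await openai_complete("Hello!", hashing_kv=hashing_kv)
    print(text)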


async def gpt_4o_complete(
    prompt,
    system_prompt=None,
    history_messages=None,
    keyword_extraction=False,
    **kwargs,
) -> str:
    if history_messages is None:
        history_messages = []
    # Honor the explicit parameter; drop any stray copy left in kwargs
    kwargs.pop("keyword_extraction", None)
    if keyword_extraction:
        kwargs["response_format"] = GPTKeywordExtractionFormat
    return await openai_complete_if_cache(
        "gpt-4o",
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        **kwargs,
    )


async def gpt_4o_mini_complete(
    prompt,
    system_prompt=None,
    history_messages=None,
    keyword_extraction=False,
    **kwargs,
) -> str:
    if history_messages is None:
        history_messages = []
    # Honor the explicit parameter; drop any stray copy left in kwargs
    kwargs.pop("keyword_extraction", None)
    if keyword_extraction:
        kwargs["response_format"] = GPTKeywordExtractionFormat
    return await openai_complete_if_cache(
        "gpt-4o-mini",
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        **kwargs,
    )
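

# Usage sketch (the prompt is illustrative). With keyword_extraction=True the
# gpt_4o_complete / gpt_4o_mini_complete wrappers set response_format to
# lightrag.types.GPTKeywordExtractionFormat, which routes the request through
# the structured-output path in openai_complete_if_cache.
async def _example_keyword_extraction() -> None:
    result = await gpt_4o_mini_complete(
        "Extract keywords: 'Graph-based RAG improves retrieval.'",
        keyword_extraction=True,
    )
    print(result)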


async def nvidia_openai_complete(
    prompt,
    system_prompt=None,
    history_messages=None,
    keyword_extraction=False,
    **kwargs,
) -> str:
    if history_messages is None:
        history_messages = []
    # Honor the explicit parameter; drop any stray copy left in kwargs
    kwargs.pop("keyword_extraction", None)
    result = await openai_complete_if_cache(
        "nvidia/llama-3.1-nemotron-70b-instruct",  # context length 128k
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        base_url="https://integrate.api.nvidia.com/v1",
        **kwargs,
    )
    if keyword_extraction:  # TODO: use JSON API
        return locate_json_string_body_from_string(result)
    return result
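

# Usage sketch (assumption: the NVIDIA_API_KEY variable name is illustrative;
# unless api_key is passed through kwargs like this, the underlying client
# falls back to OPENAI_API_KEY even for the NVIDIA endpoint).
async def _example_nvidia_complete() -> None:
    text = await nvidia_openai_complete(
        "What is retrieval-augmented generation?",
        api_key=os.environ.get("NVIDIA_API_KEY"),
    )
    print(text)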


@wrap_embedding_func_with_attrs(embedding_dim=1536)
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=60),
    retry=(
        retry_if_exception_type(RateLimitError)
        | retry_if_exception_type(APIConnectionError)
        | retry_if_exception_type(APITimeoutError)
    ),
)
async def openai_embed(
    texts: list[str],
    model: str = "text-embedding-3-small",
    base_url: str | None = None,
    api_key: str | None = None,
    client_configs: dict[str, Any] | None = None,
) -> np.ndarray:
    """Generate embeddings for a list of texts using OpenAI's API.

    Args:
        texts: List of texts to embed.
        model: The OpenAI embedding model to use.
        base_url: Optional base URL for the OpenAI API.
        api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
        client_configs: Additional configuration options for the AsyncOpenAI client.
            These will override any default configurations but will be overridden by
            explicit parameters (api_key, base_url).

    Returns:
        A numpy array of embeddings, one per input text.

    Raises:
        APIConnectionError: If there is a connection error with the OpenAI API.
        RateLimitError: If the OpenAI API rate limit is exceeded.
        APITimeoutError: If the OpenAI API request times out.
    """
    # Create the OpenAI client
    openai_async_client = create_openai_async_client(
        api_key=api_key, base_url=base_url, client_configs=client_configs
    )

    # The async context manager closes the client once the request completes
    async with openai_async_client:
        response = await openai_async_client.embeddings.create(
            model=model, input=texts, encoding_format="float"
        )
        return np.array([dp.embedding for dp in response.data])
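

# Usage sketch (the texts are illustrative). The decorator above attaches
# embedding_dim=1536, matching text-embedding-3-small, so the returned array
# has shape (len(texts), 1536).
async def _example_embed() -> None:
    vectors = await openai_embed(["hello world", "graph rag"])
    print(vectors.shape)  # (2, 1536)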