from ..utils import verbose_debug, VERBOSE_DEBUG
import sys
import os
import logging

if sys.version_info < (3, 9):
    from typing import AsyncIterator
else:
    from collections.abc import AsyncIterator

import pipmaster as pm  # Pipmaster for dynamic library install

# install specific modules
if not pm.is_installed("openai"):
    pm.install("openai")

from openai import (
    AsyncOpenAI,
    APIConnectionError,
    RateLimitError,
    APITimeoutError,
)
from tenacity import (
    retry,
    stop_after_attempt,
    wait_exponential,
    retry_if_exception_type,
)
from lightrag.utils import (
    wrap_embedding_func_with_attrs,
    locate_json_string_body_from_string,
    safe_unicode_decode,
    logger,
)
from lightrag.types import GPTKeywordExtractionFormat
from lightrag.api import __api_version__

import numpy as np
from typing import Any, Union
from dotenv import load_dotenv

# use the .env that is inside the current folder
# allows to use different .env file for each lightrag instance
# the OS environment variables take precedence over the .env file
load_dotenv(dotenv_path=".env", override=False)


class InvalidResponseError(Exception):
    """Custom exception class for triggering retry mechanism"""

    pass


def create_openai_async_client(
    api_key: str | None = None,
    base_url: str | None = None,
    client_configs: dict[str, Any] | None = None,
) -> AsyncOpenAI:
    """Create an AsyncOpenAI client with the given configuration.

    Args:
        api_key: OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
        base_url: Base URL for the OpenAI API. If None, uses the default OpenAI API URL.
        client_configs: Additional configuration options for the AsyncOpenAI client.
            These will override any default configurations but will be overridden by
            explicit parameters (api_key, base_url).

    Returns:
        An AsyncOpenAI client instance.
    """
    if not api_key:
        api_key = os.environ["OPENAI_API_KEY"]

    default_headers = {
        "User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
        "Content-Type": "application/json",
    }

    if client_configs is None:
        client_configs = {}

    # Create a merged config dict with precedence: explicit params > client_configs > defaults
    merged_configs = {
        **client_configs,
        "default_headers": default_headers,
        "api_key": api_key,
    }

    if base_url is not None:
        merged_configs["base_url"] = base_url
    else:
        merged_configs["base_url"] = os.environ.get(
            "OPENAI_API_BASE", "https://api.openai.com/v1"
        )

    return AsyncOpenAI(**merged_configs)
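
# --- Usage sketch (illustrative only, not part of the original module) ---
# A minimal example of pointing the client at an OpenAI-compatible endpoint.
# The URL and timeout below are hypothetical placeholders; OPENAI_API_KEY
# must be set in the environment when api_key is omitted.
async def _example_create_client() -> None:
    client = create_openai_async_client(
        base_url="http://localhost:8000/v1",  # hypothetical self-hosted endpoint
        client_configs={"timeout": 30.0},  # forwarded to AsyncOpenAI(...)
    )
    try:
        # List available models to verify connectivity
        models = await client.models.list()
        print([m.id for m in models.data])
    finally:
        await client.close()
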

@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=(
        retry_if_exception_type(RateLimitError)
        | retry_if_exception_type(APIConnectionError)
        | retry_if_exception_type(APITimeoutError)
        | retry_if_exception_type(InvalidResponseError)
    ),
)
async def openai_complete_if_cache(
    model: str,
    prompt: str,
    system_prompt: str | None = None,
    history_messages: list[dict[str, Any]] | None = None,
    base_url: str | None = None,
    api_key: str | None = None,
    token_tracker: Any | None = None,
    **kwargs: Any,
) -> str | AsyncIterator[str]:
    """Complete a prompt using OpenAI's API with caching support.

    Args:
        model: The OpenAI model to use.
        prompt: The prompt to complete.
        system_prompt: Optional system prompt to include.
        history_messages: Optional list of previous messages in the conversation.
        base_url: Optional base URL for the OpenAI API.
        api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
        token_tracker: Optional tracker that receives token usage counts via add_usage().
        **kwargs: Additional keyword arguments to pass to the OpenAI API.
            Special kwargs:
            - openai_client_configs: Dict of configuration options for the AsyncOpenAI client.
                These will be passed to the client constructor but will be overridden by
                explicit parameters (api_key, base_url).
            - hashing_kv: Will be removed from kwargs before passing to OpenAI.
            - keyword_extraction: Will be removed from kwargs before passing to OpenAI.

    Returns:
        The completed text or an async iterator of text chunks if streaming.

    Raises:
        InvalidResponseError: If the response from OpenAI is invalid or empty.
        APIConnectionError: If there is a connection error with the OpenAI API.
        RateLimitError: If the OpenAI API rate limit is exceeded.
        APITimeoutError: If the OpenAI API request times out.
    """
    if history_messages is None:
        history_messages = []

    # Set openai logger level to INFO when VERBOSE_DEBUG is off
    if not VERBOSE_DEBUG and logger.level == logging.DEBUG:
        logging.getLogger("openai").setLevel(logging.INFO)

    # Extract client configuration options
    client_configs = kwargs.pop("openai_client_configs", {})

    # Create the OpenAI client
    openai_async_client = create_openai_async_client(
        api_key=api_key, base_url=base_url, client_configs=client_configs
    )

    # Remove special kwargs that shouldn't be passed to OpenAI
    kwargs.pop("hashing_kv", None)
    kwargs.pop("keyword_extraction", None)

    # Prepare messages
    messages: list[dict[str, Any]] = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})

    logger.debug("===== Entering func of LLM =====")
    logger.debug(f"Model: {model} Base URL: {base_url}")
    logger.debug(f"Additional kwargs: {kwargs}")
    logger.debug(f"Num of history messages: {len(history_messages)}")
    verbose_debug(f"System prompt: {system_prompt}")
    verbose_debug(f"Query: {prompt}")
    logger.debug("===== Sending Query to LLM =====")

    try:
        # Don't use async with context manager, use client directly
        if "response_format" in kwargs:
            response = await openai_async_client.beta.chat.completions.parse(
                model=model, messages=messages, **kwargs
            )
        else:
            response = await openai_async_client.chat.completions.create(
                model=model, messages=messages, **kwargs
            )
    except APIConnectionError as e:
        logger.error(f"OpenAI API Connection Error: {e}")
        await openai_async_client.close()  # Ensure client is closed
        raise
    except RateLimitError as e:
        logger.error(f"OpenAI API Rate Limit Error: {e}")
        await openai_async_client.close()  # Ensure client is closed
        raise
    except APITimeoutError as e:
        logger.error(f"OpenAI API Timeout Error: {e}")
        await openai_async_client.close()  # Ensure client is closed
        raise
    except Exception as e:
        logger.error(
            f"OpenAI API Call Failed,\nModel: {model},\nParams: {kwargs}, Got: {e}"
        )
        await openai_async_client.close()  # Ensure client is closed
        raise

    if hasattr(response, "__aiter__"):

        async def inner():
            # Track if we've started iterating
            iteration_started = False
            try:
                iteration_started = True
                async for chunk in response:
                    # Check if choices exists and is not empty
                    if not hasattr(chunk, "choices") or not chunk.choices:
                        logger.warning(f"Received chunk without choices: {chunk}")
                        continue

                    # Check if delta exists and has content
                    if not hasattr(chunk.choices[0], "delta") or not hasattr(
                        chunk.choices[0].delta, "content"
                    ):
                        logger.warning(
                            f"Received chunk without delta content: {chunk.choices[0]}"
                        )
                        continue
                    content = chunk.choices[0].delta.content
                    if content is None:
                        continue
                    if r"\u" in content:
                        content = safe_unicode_decode(content.encode("utf-8"))
                    yield content
            except Exception as e:
                logger.error(f"Error in stream response: {str(e)}")
                # Try to clean up resources if possible
                if (
                    iteration_started
                    and hasattr(response, "aclose")
                    and callable(getattr(response, "aclose", None))
                ):
                    try:
                        await response.aclose()
                        logger.debug("Successfully closed stream response after error")
                    except Exception as close_error:
                        logger.warning(
                            f"Failed to close stream response: {close_error}"
                        )
                # Ensure client is closed in case of exception
                await openai_async_client.close()
                raise
            finally:
                # Ensure resources are released even if no exception occurs
                if (
                    iteration_started
                    and hasattr(response, "aclose")
                    and callable(getattr(response, "aclose", None))
                ):
                    try:
                        await response.aclose()
                        logger.debug("Successfully closed stream response")
                    except Exception as close_error:
                        logger.warning(
                            f"Failed to close stream response in finally block: {close_error}"
                        )
                # This prevents resource leaks since the caller doesn't handle closing
                try:
                    await openai_async_client.close()
                    logger.debug(
                        "Successfully closed OpenAI client for streaming response"
                    )
                except Exception as client_close_error:
                    logger.warning(
                        f"Failed to close OpenAI client in streaming finally block: {client_close_error}"
                    )

        return inner()

    else:
        try:
            if (
                not response
                or not response.choices
                or not hasattr(response.choices[0], "message")
                or not hasattr(response.choices[0].message, "content")
            ):
                logger.error("Invalid response from OpenAI API")
                await openai_async_client.close()  # Ensure client is closed
                raise InvalidResponseError("Invalid response from OpenAI API")

            content = response.choices[0].message.content

            if not content or content.strip() == "":
                logger.error("Received empty content from OpenAI API")
                await openai_async_client.close()  # Ensure client is closed
                raise InvalidResponseError("Received empty content from OpenAI API")

            if r"\u" in content:
                content = safe_unicode_decode(content.encode("utf-8"))

            if token_tracker and hasattr(response, "usage"):
                token_counts = {
                    "prompt_tokens": getattr(response.usage, "prompt_tokens", 0),
                    "completion_tokens": getattr(
                        response.usage, "completion_tokens", 0
                    ),
                    "total_tokens": getattr(response.usage, "total_tokens", 0),
                }
                token_tracker.add_usage(token_counts)

            logger.debug(f"Response content len: {len(content)}")
            verbose_debug(f"Response: {response}")

            return content
        finally:
            # Ensure client is closed in all cases for non-streaming responses
            await openai_async_client.close()
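
# --- Usage sketch (illustrative only, not part of the original module) ---
# Non-streaming vs. streaming calls. The model name and prompts are
# placeholders; stream=True is forwarded to the OpenAI SDK via **kwargs,
# which makes the function return an async iterator instead of a str.
async def _example_complete_if_cache() -> None:
    # Non-streaming: the full completion comes back as a single string
    text = await openai_complete_if_cache(
        "gpt-4o-mini",
        "Summarize LightRAG in one sentence.",
        system_prompt="You are a concise assistant.",
    )
    print(text)

    # Streaming: iterate over text chunks; stream and client cleanup
    # is handled inside the returned iterator
    stream = await openai_complete_if_cache(
        "gpt-4o-mini",
        "Summarize LightRAG in one sentence.",
        stream=True,
    )
    async for chunk in stream:
        print(chunk, end="", flush=True)
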

async def openai_complete(
    prompt,
    system_prompt=None,
    history_messages=None,
    keyword_extraction=False,
    **kwargs,
) -> Union[str, AsyncIterator[str]]:
    if history_messages is None:
        history_messages = []
    # Prefer the value passed via kwargs, falling back to the named parameter
    # (popping with a None default would silently discard the parameter)
    keyword_extraction = kwargs.pop("keyword_extraction", keyword_extraction)
    if keyword_extraction:
        kwargs["response_format"] = "json"
    model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
    return await openai_complete_if_cache(
        model_name,
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        **kwargs,
    )


async def gpt_4o_complete(
    prompt,
    system_prompt=None,
    history_messages=None,
    keyword_extraction=False,
    **kwargs,
) -> str:
    if history_messages is None:
        history_messages = []
    keyword_extraction = kwargs.pop("keyword_extraction", keyword_extraction)
    if keyword_extraction:
        kwargs["response_format"] = GPTKeywordExtractionFormat
    return await openai_complete_if_cache(
        "gpt-4o",
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        **kwargs,
    )


async def gpt_4o_mini_complete(
    prompt,
    system_prompt=None,
    history_messages=None,
    keyword_extraction=False,
    **kwargs,
) -> str:
    if history_messages is None:
        history_messages = []
    keyword_extraction = kwargs.pop("keyword_extraction", keyword_extraction)
    if keyword_extraction:
        kwargs["response_format"] = GPTKeywordExtractionFormat
    return await openai_complete_if_cache(
        "gpt-4o-mini",
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        **kwargs,
    )
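
# --- Usage sketch (illustrative only, not part of the original module) ---
# openai_complete resolves the model name from kwargs["hashing_kv"], so a
# standalone caller must supply an object exposing a matching global_config.
# The SimpleNamespace stand-in below is a hypothetical substitute for the
# storage object LightRAG normally passes in.
async def _example_openai_complete() -> None:
    from types import SimpleNamespace

    hashing_kv = SimpleNamespace(global_config={"llm_model_name": "gpt-4o-mini"})
    answer = await openai_complete(
        "What is retrieval-augmented generation?",
        hashing_kv=hashing_kv,
    )
    print(answer)
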

async def nvidia_openai_complete(
    prompt,
    system_prompt=None,
    history_messages=None,
    keyword_extraction=False,
    **kwargs,
) -> str:
    if history_messages is None:
        history_messages = []
    keyword_extraction = kwargs.pop("keyword_extraction", keyword_extraction)
    result = await openai_complete_if_cache(
        "nvidia/llama-3.1-nemotron-70b-instruct",  # context length 128k
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        base_url="https://integrate.api.nvidia.com/v1",
        **kwargs,
    )
    if keyword_extraction:  # TODO: use JSON API
        return locate_json_string_body_from_string(result)
    return result


@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=60),
    retry=(
        retry_if_exception_type(RateLimitError)
        | retry_if_exception_type(APIConnectionError)
        | retry_if_exception_type(APITimeoutError)
    ),
)
async def openai_embed(
    texts: list[str],
    model: str = "text-embedding-3-small",
    base_url: str | None = None,
    api_key: str | None = None,
    client_configs: dict[str, Any] | None = None,
) -> np.ndarray:
    """Generate embeddings for a list of texts using OpenAI's API.

    Args:
        texts: List of texts to embed.
        model: The OpenAI embedding model to use.
        base_url: Optional base URL for the OpenAI API.
        api_key: Optional OpenAI API key. If None, uses the OPENAI_API_KEY environment variable.
        client_configs: Additional configuration options for the AsyncOpenAI client.
            These will override any default configurations but will be overridden by
            explicit parameters (api_key, base_url).

    Returns:
        A numpy array of embeddings, one per input text.

    Raises:
        APIConnectionError: If there is a connection error with the OpenAI API.
        RateLimitError: If the OpenAI API rate limit is exceeded.
        APITimeoutError: If the OpenAI API request times out.
    """
    # Create the OpenAI client
    openai_async_client = create_openai_async_client(
        api_key=api_key, base_url=base_url, client_configs=client_configs
    )

    async with openai_async_client:
        response = await openai_async_client.embeddings.create(
            model=model, input=texts, encoding_format="float"
        )
        return np.array([dp.embedding for dp in response.data])
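
# --- Usage sketch (illustrative only, not part of the original module) ---
# Embedding a small batch of texts. Assumes OPENAI_API_KEY is set in the
# environment; the input strings are placeholders. With the default
# text-embedding-3-small model the decorator advertises 1536-dim vectors.
async def _example_openai_embed() -> None:
    vectors = await openai_embed(["hello world", "goodbye world"])
    print(vectors.shape)  # expected: (2, 1536)
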