"""LlamaIndex LLM and embedding bindings for LightRAG."""

import pipmaster as pm

# Install the required dependency before importing anything from it
if not pm.is_installed("llama-index"):
    pm.install("llama-index")

from typing import List, Optional

import numpy as np
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.llms import (
    LLM,
    ChatMessage,
    ChatResponse,
    MessageRole,
)
from llama_index.core.settings import Settings as LlamaIndexSettings
from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

from lightrag.exceptions import (
    APIConnectionError,
    APITimeoutError,
    RateLimitError,
)
from lightrag.utils import (
    locate_json_string_body_from_string,
    logger,
    wrap_embedding_func_with_attrs,
)


def configure_llama_index(settings: LlamaIndexSettings = None, **kwargs):
    """
    Configure LlamaIndex settings.

    Args:
        settings: LlamaIndex ``Settings`` object. If None, the global
            ``Settings`` singleton is used.
        **kwargs: Additional settings to override/configure
    """
    # ``Settings`` is a module-level singleton in llama_index.core, so
    # assigning to its attributes configures LlamaIndex globally.
    if settings is None:
        settings = LlamaIndexSettings

    # Update settings with any provided kwargs
    for key, value in kwargs.items():
        if hasattr(settings, key):
            setattr(settings, key, value)
        else:
            logger.warning(f"Unknown LlamaIndex setting: {key}")

    return settings


def format_chat_messages(messages):
    """Convert ``{"role": ..., "content": ...}`` dicts into LlamaIndex ChatMessages."""
    formatted_messages = []

    for msg in messages:
        role = msg.get("role", "user")
        content = msg.get("content", "")

        if role == "system":
            formatted_messages.append(
                ChatMessage(role=MessageRole.SYSTEM, content=content)
            )
        elif role == "assistant":
            formatted_messages.append(
                ChatMessage(role=MessageRole.ASSISTANT, content=content)
            )
        elif role == "user":
            formatted_messages.append(
                ChatMessage(role=MessageRole.USER, content=content)
            )
        else:
            logger.warning(f"Unknown role {role}, treating as user message")
            formatted_messages.append(
                ChatMessage(role=MessageRole.USER, content=content)
            )

    return formatted_messages


@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=60),
    retry=retry_if_exception_type(
        (RateLimitError, APIConnectionError, APITimeoutError)
    ),
)
async def llama_index_complete_if_cache(
    model: LLM,
    prompt: str,
    system_prompt: Optional[str] = None,
    history_messages: Optional[List[dict]] = None,
    chat_kwargs: Optional[dict] = None,
) -> str:
    """Complete the prompt using a LlamaIndex LLM instance."""
    try:
        history_messages = history_messages or []
        chat_kwargs = chat_kwargs or {}

        # Format messages for chat
        formatted_messages = []

        # Add system message if provided
        if system_prompt:
            formatted_messages.append(
                ChatMessage(role=MessageRole.SYSTEM, content=system_prompt)
            )

        # Add history messages
        for msg in history_messages:
            formatted_messages.append(
                ChatMessage(
                    role=MessageRole.USER
                    if msg["role"] == "user"
                    else MessageRole.ASSISTANT,
                    content=msg["content"],
                )
            )

        # Add current prompt
        formatted_messages.append(ChatMessage(role=MessageRole.USER, content=prompt))

        response: ChatResponse = await model.achat(
            messages=formatted_messages, **chat_kwargs
        )

        # In newer LlamaIndex versions, the text lives in message.content
        return response.message.content
    except Exception as e:
        logger.error(f"Error in llama_index_complete_if_cache: {str(e)}")
        raise
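

# --- Illustrative usage sketch (not called anywhere in LightRAG) -------------
# Shows how the completion helper above can be driven with a concrete
# LlamaIndex LLM. It assumes the optional ``llama-index-llms-openai`` package
# is installed; any LLM exposing ``achat`` can be substituted, and the model
# name below is only an example.
async def _example_complete_with_openai() -> str:
    from llama_index.llms.openai import OpenAI  # optional extra, assumed installed

    llm = OpenAI(model="gpt-4o-mini", temperature=0.0)
    return await llama_index_complete_if_cache(
        llm,
        "Summarize the benefits of graph-based RAG in two sentences.",
        system_prompt="You are a concise technical writer.",
        history_messages=[
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi! How can I help?"},
        ],
    )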


async def llama_index_complete(
    prompt,
    system_prompt=None,
    history_messages=None,
    keyword_extraction=False,
    settings: LlamaIndexSettings = None,
    **kwargs,
) -> str:
    """
    Main completion function for LlamaIndex.

    Args:
        prompt: Input prompt
        system_prompt: Optional system prompt
        history_messages: Optional chat history
        keyword_extraction: Whether to extract a JSON payload from the response
        settings: Optional LlamaIndex settings
        **kwargs: Must contain ``llm_instance`` (the LlamaIndex LLM to use);
            any remaining keys are forwarded to the chat call
    """
    if history_messages is None:
        history_messages = []
    if settings:
        configure_llama_index(settings)

    # Honor the explicit argument, but let a kwargs override win; pop it so it
    # is not forwarded to the chat call.
    keyword_extraction = kwargs.pop("keyword_extraction", keyword_extraction)

    # Pop the LLM instance so only genuine chat kwargs remain in ``kwargs``.
    llm_instance = kwargs.pop("llm_instance", None)
    if llm_instance is None:
        raise ValueError("llm_instance must be provided in kwargs")

    result = await llama_index_complete_if_cache(
        llm_instance,
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        chat_kwargs=kwargs,
    )
    if keyword_extraction:
        return locate_json_string_body_from_string(result)
    return result


@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=60),
    retry=retry_if_exception_type(
        (RateLimitError, APIConnectionError, APITimeoutError)
    ),
)
async def llama_index_embed(
    texts: list[str],
    embed_model: Optional[BaseEmbedding] = None,
    settings: LlamaIndexSettings = None,
    **kwargs,
) -> np.ndarray:
    """
    Generate embeddings using LlamaIndex.

    Args:
        texts: List of texts to embed
        embed_model: LlamaIndex embedding model
        settings: Optional LlamaIndex settings
        **kwargs: Additional arguments (currently unused)
    """
    if settings:
        configure_llama_index(settings)

    if embed_model is None:
        raise ValueError("embed_model must be provided")

    # Embed the whole batch in one call through the public async batch API
    embeddings = await embed_model.aget_text_embedding_batch(texts)
    return np.array(embeddings)
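

# --- Illustrative usage sketch (not called anywhere in LightRAG) -------------
# Shows how ``llama_index_embed`` might be fed a concrete LlamaIndex embedding
# model. It assumes the optional ``llama-index-embeddings-openai`` package; any
# ``BaseEmbedding`` subclass works the same way. The embedding_dim=1536
# attribute above matches OpenAI's text-embedding-3-small; adjust the decorator
# if your model's dimensionality differs.
async def _example_embed_with_openai() -> np.ndarray:
    from llama_index.embeddings.openai import OpenAIEmbedding  # optional extra, assumed installed

    embed_model = OpenAIEmbedding(model="text-embedding-3-small")
    return await llama_index_embed(
        [
            "LightRAG builds a knowledge graph from source documents.",
            "LlamaIndex supplies the LLM and embedding backends.",
        ],
        embed_model=embed_model,
    )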