mirror of https://github.com/HKUDS/LightRAG.git
synced 2025-11-04 03:39:35 +00:00

Merge pull request #1167 from omdivyatej/om-pr

Feature: Dynamic LLM Selection via QueryParam for Optimized Performance

This commit is contained in:
commit ec15d5a5af
examples/lightrag_multi_model_all_modes_demo.py  +88  (Normal file)
@@ -0,0 +1,88 @@
+import os
+import asyncio
+from lightrag import LightRAG, QueryParam
+from lightrag.llm.openai import gpt_4o_mini_complete, gpt_4o_complete, openai_embed
+from lightrag.kg.shared_storage import initialize_pipeline_status
+
+WORKING_DIR = "./lightrag_demo"
+
+if not os.path.exists(WORKING_DIR):
+    os.mkdir(WORKING_DIR)
+
+
+async def initialize_rag():
+    rag = LightRAG(
+        working_dir=WORKING_DIR,
+        embedding_func=openai_embed,
+        llm_model_func=gpt_4o_mini_complete,  # Default model for queries
+    )
+
+    await rag.initialize_storages()
+    await initialize_pipeline_status()
+
+    return rag
+
+
+def main():
+    # Initialize RAG instance
+    rag = asyncio.run(initialize_rag())
+
+    # Load the data
+    with open("./book.txt", "r", encoding="utf-8") as f:
+        rag.insert(f.read())
+
+    # Query with naive mode (default model)
+    print("--- NAIVE mode ---")
+    print(
+        rag.query(
+            "What are the main themes in this story?", param=QueryParam(mode="naive")
+        )
+    )
+
+    # Query with local mode (default model)
+    print("\n--- LOCAL mode ---")
+    print(
+        rag.query(
+            "What are the main themes in this story?", param=QueryParam(mode="local")
+        )
+    )
+
+    # Query with global mode (default model)
+    print("\n--- GLOBAL mode ---")
+    print(
+        rag.query(
+            "What are the main themes in this story?", param=QueryParam(mode="global")
+        )
+    )
+
+    # Query with hybrid mode (default model)
+    print("\n--- HYBRID mode ---")
+    print(
+        rag.query(
+            "What are the main themes in this story?", param=QueryParam(mode="hybrid")
+        )
+    )
+
+    # Query with mix mode (default model)
+    print("\n--- MIX mode ---")
+    print(
+        rag.query(
+            "What are the main themes in this story?", param=QueryParam(mode="mix")
+        )
+    )
+
+    # Query with a custom model (gpt-4o) for a more complex question
+    print("\n--- Using custom model for complex analysis ---")
+    print(
+        rag.query(
+            "How does the character development reflect Victorian-era attitudes?",
+            param=QueryParam(
+                mode="global",
+                model_func=gpt_4o_complete,  # Override default model with more capable one
+            ),
+        )
+    )
+
+
+if __name__ == "__main__":
+    main()
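The demo above drives everything through the synchronous rag.query() wrapper. A minimal async sketch of the same override is shown below; it assumes it sits next to the demo so initialize_rag() is in scope, and that aquery() accepts the same QueryParam as query() (the lightrag.py hunks further down touch aquery, so that is the path the override ultimately goes through).

# Sketch only: async variant of the per-query override. Assumes it lives in the
# same file as the demo above so that initialize_rag() is available.
import asyncio

from lightrag import QueryParam
from lightrag.llm.openai import gpt_4o_complete


async def ask_with_override():
    rag = await initialize_rag()  # reuse the demo's setup
    return await rag.aquery(
        "How does the character development reflect Victorian-era attitudes?",
        param=QueryParam(mode="hybrid", model_func=gpt_4o_complete),
    )


print(asyncio.run(ask_with_override()))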
@@ -10,6 +10,7 @@ from typing import (
     Literal,
     TypedDict,
     TypeVar,
+    Callable,
 )
 import numpy as np
 from .utils import EmbeddingFunc
@@ -84,6 +85,12 @@ class QueryParam:
     ids: list[str] | None = None
     """List of ids to filter the results."""
 
+    model_func: Callable[..., object] | None = None
+    """Optional override for the LLM model function to use for this specific query.
+    If provided, this will be used instead of the global model function.
+    This allows using different models for different query modes.
+    """
+
 
 @dataclass
 class StorageNameSpace(ABC):
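Because model_func is typed only as Callable[..., object] | None, any callable with the same calling convention as the bundled completion functions should work, not just the OpenAI helpers used in the demo. The sketch below is illustrative only (route_complete and the length threshold are invented here); the operate.py hunks further down await the override directly, e.g. use_model_func(kw_prompt, keyword_extraction=True), so a custom callable should be async and tolerate extra keyword arguments.

# Illustrative only: a hand-rolled callable for QueryParam.model_func.
# Assumes the same calling convention as gpt_4o_mini_complete/gpt_4o_complete:
# async, first positional argument is the prompt, extra kwargs passed through.
from lightrag import QueryParam
from lightrag.llm.openai import gpt_4o_complete, gpt_4o_mini_complete


async def route_complete(prompt: str, **kwargs) -> str:
    # Hypothetical routing rule: long prompts go to the stronger model.
    if len(prompt) > 4000:
        return await gpt_4o_complete(prompt, **kwargs)
    return await gpt_4o_mini_complete(prompt, **kwargs)


param = QueryParam(mode="hybrid", model_func=route_complete)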
@@ -1330,11 +1330,15 @@ class LightRAG:
         Args:
             query (str): The query to be executed.
             param (QueryParam): Configuration parameters for query execution.
+                If param.model_func is provided, it will be used instead of the global model.
             prompt (Optional[str]): Custom prompts for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"].
 
         Returns:
             str: The result of the query execution.
         """
+        # If a custom model is provided in param, temporarily update global config
+        global_config = asdict(self)
+
         if param.mode in ["local", "global", "hybrid"]:
             response = await kg_query(
                 query.strip(),
@@ -1343,7 +1347,7 @@
                 self.relationships_vdb,
                 self.text_chunks,
                 param,
-                asdict(self),
+                global_config,
                 hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
                 system_prompt=system_prompt,
             )
@@ -1353,7 +1357,7 @@
                 self.chunks_vdb,
                 self.text_chunks,
                 param,
-                asdict(self),
+                global_config,
                 hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
                 system_prompt=system_prompt,
             )
@@ -1366,7 +1370,7 @@
                 self.chunks_vdb,
                 self.text_chunks,
                 param,
-                asdict(self),
+                global_config,
                 hashing_kv=self.llm_response_cache,  # Directly use llm_response_cache
                 system_prompt=system_prompt,
             )
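In aquery() the hunks above only bind asdict(self) to global_config and pass it through; the actual model selection happens in the query operators below. Each of them replaces the fixed lookup with the same fallback, which (assuming model_func is either None or a callable, never a falsy-but-valid value) boils down to the one-liner sketched here.

# Equivalent shorthand for the fallback repeated in the operate.py hunks below:
use_model_func = query_param.model_func or global_config["llm_model_func"]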
@@ -705,7 +705,11 @@ async def kg_query(
     system_prompt: str | None = None,
 ) -> str | AsyncIterator[str]:
     # Handle cache
-    use_model_func = global_config["llm_model_func"]
+    use_model_func = (
+        query_param.model_func
+        if query_param.model_func
+        else global_config["llm_model_func"]
+    )
     args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, query_param.mode, cache_type="query"
@@ -866,7 +870,9 @@ async def extract_keywords_only(
     logger.debug(f"[kg_query]Prompt Tokens: {len_of_prompts}")
 
     # 5. Call the LLM for keyword extraction
-    use_model_func = global_config["llm_model_func"]
+    use_model_func = (
+        param.model_func if param.model_func else global_config["llm_model_func"]
+    )
     result = await use_model_func(kw_prompt, keyword_extraction=True)
 
     # 6. Parse out JSON from the LLM response
@@ -926,7 +932,11 @@ async def mix_kg_vector_query(
     3. Combining both results for comprehensive answer generation
     """
     # 1. Cache handling
-    use_model_func = global_config["llm_model_func"]
+    use_model_func = (
+        query_param.model_func
+        if query_param.model_func
+        else global_config["llm_model_func"]
+    )
     args_hash = compute_args_hash("mix", query, cache_type="query")
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, "mix", cache_type="query"
@@ -1731,7 +1741,11 @@ async def naive_query(
     system_prompt: str | None = None,
 ) -> str | AsyncIterator[str]:
     # Handle cache
-    use_model_func = global_config["llm_model_func"]
+    use_model_func = (
+        query_param.model_func
+        if query_param.model_func
+        else global_config["llm_model_func"]
+    )
     args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, query_param.mode, cache_type="query"
@@ -1850,7 +1864,11 @@ async def kg_query_with_keywords(
     # ---------------------------
     # 1) Handle potential cache for query results
     # ---------------------------
-    use_model_func = global_config["llm_model_func"]
+    use_model_func = (
+        query_param.model_func
+        if query_param.model_func
+        else global_config["llm_model_func"]
+    )
     args_hash = compute_args_hash(query_param.mode, query, cache_type="query")
     cached_response, quantized, min_val, max_val = await handle_cache(
         hashing_kv, args_hash, query, query_param.mode, cache_type="query"
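A quick way to confirm the override actually takes effect is to pass a stub that tags its output; if the tag comes back, the per-query function was used instead of the instance-wide one. This is a throwaway check, not part of the PR: it assumes a rag instance initialized as in the demo above and a stub matching the calling convention noted earlier (async, prompt first, extra kwargs accepted).

# Throwaway check (not from the PR): a stub model_func that tags its answer.
from lightrag import QueryParam


async def stub_model_func(prompt: str, **kwargs) -> str:
    return "[stub-model] override was used"


print(
    rag.query(
        "Who narrates the story?",
        param=QueryParam(mode="naive", model_func=stub_model_func),
    )
)
# If the printed answer starts with "[stub-model]", QueryParam.model_func won.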