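"""
LightRAG demo using NVIDIA's OpenAI-compatible endpoints: a custom llm_model_func
backed by nvidia/llama-3.1-nemotron-70b-instruct, and nvidia/nv-embedqa-e5-v5
embeddings with separate passage/query input types for indexing and querying.
"""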
import os
import asyncio
from lightrag import LightRAG, QueryParam
from lightrag.llm import (
    openai_complete_if_cache,
    nvidia_openai_embedding,
)
from lightrag.utils import EmbeddingFunc
import numpy as np

# for the custom llm_model_func (keyword-extraction JSON parsing)
from lightrag.utils import locate_json_string_body_from_string

WORKING_DIR = "./dickens"

if not os.path.exists(WORKING_DIR):
    os.mkdir(WORKING_DIR)

# choose one way to provide your API key
# NVIDIA_OPENAI_API_KEY = os.getenv("NVIDIA_OPENAI_API_KEY")
NVIDIA_OPENAI_API_KEY = "nvapi-xxxx"  # your api key
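# (keys start with "nvapi-" and are issued through NVIDIA's API catalog, e.g. https://build.nvidia.com)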
					
						
# using the predefined helper for the NVIDIA LLM API (OpenAI-compatible):
# llm_model_func = nvidia_openai_complete


# or define a custom llm_model_func that calls an LLM on the NVIDIA API, as in the other examples:
async def llm_model_func(
    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
    result = await openai_complete_if_cache(
        "nvidia/llama-3.1-nemotron-70b-instruct",
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        api_key=NVIDIA_OPENAI_API_KEY,
        base_url="https://integrate.api.nvidia.com/v1",
        **kwargs,
    )
    if keyword_extraction:
        return locate_json_string_body_from_string(result)
    return result


# custom embedding
nvidia_embed_model = "nvidia/nv-embedqa-e5-v5"
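# nv-embedqa-e5-v5 is a retrieval model with asymmetric embeddings: documents are embedded
# with input_type="passage" and questions with input_type="query", which is why there are
# separate indexing and query wrapper functions below.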
					
						
async def indexing_embedding_func(texts: list[str]) -> np.ndarray:
    return await nvidia_openai_embedding(
        texts,
        model=nvidia_embed_model,  # maximum 512 tokens per input
        # model="nvidia/llama-3.2-nv-embedqa-1b-v1",
        api_key=NVIDIA_OPENAI_API_KEY,
        base_url="https://integrate.api.nvidia.com/v1",
        input_type="passage",
        trunc="END",  # truncate server-side if an input exceeds the token limit
        encode="float",
    )


async def query_embedding_func(texts: list[str]) -> np.ndarray:
    return await nvidia_openai_embedding(
        texts,
        model=nvidia_embed_model,  # maximum 512 tokens per input
        # model="nvidia/llama-3.2-nv-embedqa-1b-v1",
        api_key=NVIDIA_OPENAI_API_KEY,
        base_url="https://integrate.api.nvidia.com/v1",
        input_type="query",
        trunc="END",  # truncate server-side if an input exceeds the token limit
        encode="float",
    )
					
						
# passage and query embeddings have the same dimension, so probe it once
async def get_embedding_dim():
    test_text = ["This is a test sentence."]
    embedding = await indexing_embedding_func(test_text)
    embedding_dim = embedding.shape[1]
    return embedding_dim


# quick smoke test of the LLM and embedding endpoints (uncomment the call below to run it)
async def test_funcs():
    result = await llm_model_func("How are you?")
    print("llm_model_func: ", result)

    result = await indexing_embedding_func(["How are you?"])
    print("embedding_func: ", result)


# asyncio.run(test_funcs())
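
# The flow below: build the index with passage-type embeddings, then re-create the LightRAG
# instance with query-type embeddings and run the same question in the four query modes
# (naive, local, global, hybrid).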
					
						
async def main():
    try:
        embedding_dimension = await get_embedding_dim()
        print(f"Detected embedding dimension: {embedding_dimension}")

        # LightRAG instance used for indexing
        rag = LightRAG(
            working_dir=WORKING_DIR,
            llm_model_func=llm_model_func,
            # llm_model_name="meta/llama3-70b-instruct",  # uncomment if needed
            embedding_func=EmbeddingFunc(
                embedding_dim=embedding_dimension,
                max_token_size=512,  # inputs may still exceed the model's limit because
                # LightRAG's tokenizer differs from the NVIDIA model's; the trunc="END"
                # option in the embedding functions handles that overflow on the server
                # side (aligning the tokenizer with the NVIDIA model is future work)
                func=indexing_embedding_func,
            ),
        )

        # read the document to index (./book.txt can be any plain-text file)
        with open("./book.txt", "r", encoding="utf-8") as f:
            await rag.ainsert(f.read())

        # re-create LightRAG on the same working_dir so the index built above is reused,
        # but with the query-type embedding function
        rag = LightRAG(
            working_dir=WORKING_DIR,
            llm_model_func=llm_model_func,
            # llm_model_name="meta/llama3-70b-instruct",  # uncomment if needed
            embedding_func=EmbeddingFunc(
                embedding_dim=embedding_dimension,
                max_token_size=512,
                func=query_embedding_func,
            ),
        )

        # Perform naive search
        print("==============Naive===============")
        print(
            await rag.aquery(
                "What are the top themes in this story?", param=QueryParam(mode="naive")
            )
        )

        # Perform local search
        print("==============local===============")
        print(
            await rag.aquery(
                "What are the top themes in this story?", param=QueryParam(mode="local")
            )
        )

        # Perform global search
        print("==============global===============")
        print(
            await rag.aquery(
                "What are the top themes in this story?",
                param=QueryParam(mode="global"),
            )
        )

        # Perform hybrid search
        print("==============hybrid===============")
        print(
            await rag.aquery(
                "What are the top themes in this story?",
                param=QueryParam(mode="hybrid"),
            )
        )
    except Exception as e:
        print(f"An error occurred: {e}")


if __name__ == "__main__":
    asyncio.run(main())
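
# Prerequisites (assumed, not shown above): the lightrag package installed in the environment
# (published on PyPI as lightrag-hku), a valid NVIDIA API key, and a plain-text ./book.txt in
# the current directory to index.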