mirror of
				https://github.com/HKUDS/LightRAG.git
				synced 2025-10-31 17:59:36 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			102 lines
		
	
	
		
			3.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			102 lines
		
	
	
		
			3.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| from __future__ import annotations
 | |
| 
 | |
| from typing import Callable, Any
 | |
| from pydantic import BaseModel, Field
 | |
| 
 | |
| 
 | |
class Model(BaseModel):
    """
    Configuration for one language-model backend: a completion function plus
    the fixed keyword arguments to call it with.

    Attributes:
        gen_func: The completion function. It is invoked with keyword
            arguments only (``prompt``, ``system_prompt``,
            ``history_messages`` and everything in ``kwargs``) and its result
            is awaited, so it should be an ``async`` function that resolves to
            the response string.
        kwargs: Keyword arguments passed to ``gen_func`` on every call, such
            as the model name or the API key.

    Example usage:
        Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_1"]})

    In this example, 'openai_complete_if_cache' is the function that generates
    the response from the OpenAI model, and 'kwargs' carries the model name
    and API key that are passed to it.
    """

    # Typed as ``Callable[..., Any]`` because the function is called with
    # keyword arguments and returns an awaitable; the previous
    # ``Callable[[Any], str]`` described a one-positional-argument sync call.
    gen_func: Callable[..., Any] = Field(
        ...,
        description="A function that generates the response from the llm. The response must be a string",
    )
    kwargs: dict[str, Any] = Field(
        ...,
        description="The arguments to pass to the callable function. Eg. the api key, model name, etc",
    )

    class Config:
        arbitrary_types_allowed = True
 | |
| 
 | |
| 
 | |
class MultiModel:
    """
    Distributes the load across multiple language models in round-robin order.

    Useful for circumventing low rate limits with certain API providers,
    especially on a free tier. It can also be used for splitting requests
    across different models or providers.

    Attributes:
        _models (list[Model]): The language models to rotate through.
        _current_model (int): Index of the model the next call will use.

    Usage example:
        ```python
        models = [
            Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_1"]}),
            Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_2"]}),
            Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_3"]}),
            Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_4"]}),
            Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_5"]}),
        ]
        multi_model = MultiModel(models)
        rag = LightRAG(
            llm_model_func=multi_model.llm_model_func,
            # ...other args
        )
        ```
    """

    def __init__(self, models: list[Model]):
        """
        Store the models to rotate through.

        Args:
            models: Non-empty list of model configurations.

        Raises:
            ValueError: If ``models`` is empty. Without this check the first
                call would fail later with an opaque ``ZeroDivisionError``.
        """
        if not models:
            raise ValueError("MultiModel requires at least one model")
        self._models = models
        self._current_model = 0

    def _next_model(self):
        """Return the model for the current call and advance the rotation."""
        # Reduce modulo the length first so the index stays valid even if the
        # caller's list was shortened after construction.
        index = self._current_model % len(self._models)
        selected = self._models[index]
        # Select first, then advance, so the very first call uses models[0].
        self._current_model = (index + 1) % len(self._models)
        return selected

    async def llm_model_func(
        self,
        prompt: str,
        system_prompt: str | None = None,
        history_messages: list[dict[str, Any]] | None = None,
        **kwargs: Any,
    ) -> str:
        """
        Forward a completion request to the next model in the rotation.

        Args:
            prompt: The user prompt.
            system_prompt: Optional system prompt.
            history_messages: Prior conversation messages. ``None`` (the
                default) is forwarded as a fresh empty list.
            **kwargs: Extra keyword arguments forwarded to the model's
                ``gen_func``. ``model``, ``keyword_extraction`` and ``mode``
                are dropped, and any key also present in the selected model's
                own ``kwargs`` is overridden by the model's value.

        Returns:
            The awaited result of the selected model's ``gen_func``.
        """
        # A fresh list per call avoids sharing one mutable default between
        # invocations.
        if history_messages is None:
            history_messages = []

        kwargs.pop("model", None)  # stop from overwriting the custom model name
        kwargs.pop("keyword_extraction", None)
        kwargs.pop("mode", None)

        next_model = self._next_model()

        # Merge instead of double-splatting into dict(): duplicate keys (for
        # example a caller-supplied ``api_key``) would otherwise raise
        # TypeError. The per-model configuration takes precedence.
        call_kwargs = {**kwargs, **next_model.kwargs}

        return await next_model.gen_func(
            prompt=prompt,
            system_prompt=system_prompt,
            history_messages=history_messages,
            **call_kwargs,
        )
 | |
| 
 | |
| 
 | |
if __name__ == "__main__":
    # Manual smoke test, run only when this file is executed as a script.
    # It sends a single prompt through gpt_4o_mini_complete and prints the
    # reply; it does not exercise MultiModel itself.
    import asyncio

    async def main():
        # Imported lazily so that importing this module never pulls in the
        # OpenAI bindings.
        from lightrag.llm.openai import gpt_4o_mini_complete

        print(await gpt_4o_mini_complete("How are you?"))

    asyncio.run(main())
 | 
