from __future__ import annotations

from typing import Any, Awaitable, Callable

from pydantic import BaseModel, Field

class Model(BaseModel):
    """
    A Pydantic model used to define a custom language model.

    Attributes:
        gen_func (Callable[..., Awaitable[str]]): An async callable that generates
            the response from the language model. It is awaited with keyword
            arguments and must resolve to a string.
        kwargs (dict[str, Any]): A dictionary of arguments to pass to the callable,
            such as the model name, API key, etc.

    Example usage:
        Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_1"]})

    In this example, 'openai_complete_if_cache' is the callable that generates the
    response from the OpenAI model, and 'kwargs' carries the model name and API key
    to pass to it.
    """

    # The wrapped callable is awaited in MultiModel.llm_model_func, so it must
    # return an awaitable that resolves to a string.
    gen_func: Callable[..., Awaitable[str]] = Field(
        ...,
        description="A function that generates the response from the llm. The response must be a string.",
    )
    kwargs: dict[str, Any] = Field(
        ...,
        description="The arguments to pass to the callable function, e.g. the API key, model name, etc.",
    )

    class Config:
        arbitrary_types_allowed = True
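
# A minimal sketch of a conforming gen_func, for illustration only: any async
# callable that accepts keyword arguments (prompt, system_prompt,
# history_messages, plus the per-model kwargs) and returns a string will do.
# The names below are hypothetical and not part of this module:
#
#     async def echo_complete(prompt: str, **kwargs: Any) -> str:
#         return f"echo: {prompt}"
#
#     model = Model(gen_func=echo_complete, kwargs={"model": "echo-v1"})
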
class MultiModel:
    """
    Distributes the load across multiple language models. Useful for working
    around low rate limits with certain API providers, especially on free tiers.
    Can also be used to split requests across different models or providers.

    Attributes:
        models (list[Model]): A list of language models to use.

    Usage example:
    ```python
    models = [
        Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_1"]}),
        Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_2"]}),
        Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_3"]}),
        Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_4"]}),
        Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_5"]}),
    ]
    multi_model = MultiModel(models)
    rag = LightRAG(
        llm_model_func=multi_model.llm_model_func,
        # ...other args
    )
    ```
    """

    def __init__(self, models: list[Model]):
        self._models = models
        self._current_model = 0

    def _next_model(self):
        # Round-robin: advance the index, wrapping around at the end of the list.
        self._current_model = (self._current_model + 1) % len(self._models)
        return self._models[self._current_model]

    async def llm_model_func(
        self,
        prompt: str,
        system_prompt: str | None = None,
        history_messages: list[dict[str, Any]] | None = None,
        **kwargs: Any,
    ) -> str:
        # Avoid a mutable default argument.
        if history_messages is None:
            history_messages = []
        kwargs.pop("model", None)  # stop callers from overwriting the per-model name
        kwargs.pop("keyword_extraction", None)
        kwargs.pop("mode", None)
        next_model = self._next_model()
        # Build the arguments with a dict literal so that duplicate keys do not
        # raise a TypeError; the per-model kwargs take precedence over the
        # call-site kwargs.
        args = {
            "prompt": prompt,
            "system_prompt": system_prompt,
            "history_messages": history_messages,
            **kwargs,
            **next_model.kwargs,
        }

        return await next_model.gen_func(**args)

if __name__ == "__main__":
|
|
import asyncio
|
|
|
|
async def main():
|
|
from lightrag.llm.openai import gpt_4o_mini_complete
|
|
|
|
result = await gpt_4o_mini_complete("How are you?")
|
|
print(result)
|
|
|
|
asyncio.run(main())
|
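
    # A minimal, self-contained sketch of MultiModel's round-robin behaviour.
    # It assumes no real provider: 'fake_llm', 'dummy-a', and 'dummy-b' are
    # hypothetical stand-ins, not part of LightRAG.
    async def demo_multi_model():
        async def fake_llm(prompt: str, **kwargs: Any) -> str:
            return f"[{kwargs.get('model')}] {prompt}"

        multi_model = MultiModel(
            [
                Model(gen_func=fake_llm, kwargs={"model": "dummy-a"}),
                Model(gen_func=fake_llm, kwargs={"model": "dummy-b"}),
            ]
        )
        # Consecutive calls rotate between the two dummy models.
        print(await multi_model.llm_model_func("Hello?"))
        print(await multi_model.llm_model_func("Hello again?"))

    asyncio.run(demo_multi_model())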