2025-02-15 22:37:12 +01:00
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
from typing import Callable, Any
|
2024-12-06 10:28:35 +08:00
|
|
|
from pydantic import BaseModel, Field
|
2024-12-22 00:38:38 +01:00
|
|
|
|
|
|
|
|
2024-10-21 18:34:43 +01:00
|
|
|
class Model(BaseModel):
    """Wrap a single LLM backend as a callable plus its invocation arguments.

    Attributes:
        gen_func: Callable that produces the response from the language
            model; it may take any arguments and must return a string.
        kwargs: Keyword arguments forwarded to ``gen_func`` on every call,
            e.g. the model name, API key, etc.

    Example:
        Model(
            gen_func=openai_complete_if_cache,
            kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_1"]},
        )

        Here ``openai_complete_if_cache`` is the callable that generates the
        response from the OpenAI model, and ``kwargs`` carries the model name
        and API key it needs.
    """

    gen_func: Callable[[Any], str] = Field(
        ...,
        description="A function that generates the response from the llm. The response must be a string",
    )
    kwargs: dict[str, Any] = Field(
        ...,
        description="The arguments to pass to the callable function. Eg. the api key, model name, etc",
    )

    class Config:
        # Plain callables are not types pydantic can validate natively.
        arbitrary_types_allowed = True
|
|
|
|
|
|
|
|
|
2024-10-25 13:32:25 +05:30
|
|
|
class MultiModel:
    """Round-robin dispatcher across multiple language models.

    Distributes the load across multiple language models. Useful for
    circumventing low rate limits with certain API providers, especially if
    you are on the free tier. Could also be used for splitting load across
    different models or providers.

    Attributes:
        _models (list[Model]): The language models to rotate through.
        _current_model (int): Index of the most recently used model.

    Usage example:
    ```python
    models = [
        Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_1"]}),
        Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_2"]}),
        Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_3"]}),
        Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_4"]}),
        Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_5"]}),
    ]
    multi_model = MultiModel(models)
    rag = LightRAG(
        llm_model_func=multi_model.llm_model_func
        # ..other args
    )
    ```
    """

    def __init__(self, models: list[Model]):
        # Rotation starts at index 0; the first call to _next_model() will
        # therefore return the model at index 1 (or index 0 for a single model).
        self._models = models
        self._current_model = 0

    def _next_model(self) -> Model:
        """Advance the rotation cursor (wrapping around) and return the selected model."""
        self._current_model = (self._current_model + 1) % len(self._models)
        return self._models[self._current_model]

    async def llm_model_func(
        self,
        prompt: str,
        system_prompt: str | None = None,
        history_messages: list[dict[str, Any]] | None = None,
        **kwargs: Any,
    ) -> str:
        """Generate a response using the next model in the rotation.

        Args:
            prompt: The user prompt to send to the model.
            system_prompt: Optional system prompt.
            history_messages: Prior conversation turns; ``None`` means no
                history (treated as an empty list).
            **kwargs: Extra arguments forwarded to the model's ``gen_func``.
                The keys ``model``, ``keyword_extraction`` and ``mode`` are
                dropped so callers cannot override the per-model settings.

        Returns:
            The response string produced by the selected model's ``gen_func``.
        """
        # FIX: the original signature used a mutable default (history_messages=[]),
        # which is shared across all calls; use a None sentinel instead.
        if history_messages is None:
            history_messages = []
        kwargs.pop("model", None)  # stop from overwriting the custom model name
        kwargs.pop("keyword_extraction", None)
        kwargs.pop("mode", None)
        next_model = self._next_model()
        args = dict(
            prompt=prompt,
            system_prompt=system_prompt,
            history_messages=history_messages,
            **kwargs,
            **next_model.kwargs,
        )
        return await next_model.gen_func(**args)
|
|
|
|
|
|
|
|
|
2024-10-10 15:02:30 +08:00
|
|
|
if __name__ == "__main__":
|
|
|
|
import asyncio
|
|
|
|
|
|
|
|
async def main():
|
2025-01-25 00:11:00 +01:00
|
|
|
from lightrag.llm.openai import gpt_4o_mini_complete
|
|
|
|
|
2024-10-19 09:43:17 +05:30
|
|
|
result = await gpt_4o_mini_complete("How are you?")
|
2024-10-10 15:02:30 +08:00
|
|
|
print(result)
|
|
|
|
|
2024-11-06 11:18:14 -05:00
|
|
|
asyncio.run(main())
|