2025-02-15 22:37:12 +01:00

102 lines
3.7 KiB
Python

from __future__ import annotations

from typing import Any, Awaitable, Callable

from pydantic import BaseModel, Field
class Model(BaseModel):
    """Configuration for a single language-model backend.

    Pairs a generation function with the keyword arguments that accompany
    every call made to it.

    Attributes:
        gen_func: Callable that produces the response from the language
            model. ``MultiModel.llm_model_func`` awaits it and passes
            ``prompt``, ``system_prompt``, ``history_messages`` and the
            entries of ``kwargs`` as keyword arguments, so it must return an
            awaitable that resolves to a string.
        kwargs: Keyword arguments to pass to ``gen_func``, for example the
            model name or the API key.

    Example usage:
        Model(
            gen_func=openai_complete_if_cache,
            kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_1"]},
        )

    In this example ``openai_complete_if_cache`` is the callable that
    generates the response, and ``kwargs`` carries the model name and API key
    handed to it.
    """

    # Called with keyword arguments only and awaited by MultiModel, hence the
    # open parameter list and the awaitable return type.
    gen_func: Callable[..., Awaitable[str]] = Field(
        ...,
        description="A function that generates the response from the llm. The response must be a string",
    )
    kwargs: dict[str, Any] = Field(
        ...,
        description="The arguments to pass to the callable function. Eg. the api key, model name, etc",
    )

    class Config:
        arbitrary_types_allowed = True
class MultiModel:
    """Round-robin dispatcher over several language models.

    Distributes the load across multiple language models. Useful for
    circumventing low rate limits with certain API providers, especially if
    you are on the free tier. Could also be used for splitting requests
    across different models or providers.

    Args:
        models: Non-empty list of language models to rotate through.

    Raises:
        ValueError: If ``models`` is empty.

    Usage example:
    ```python
    models = [
        Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_1"]}),
        Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_2"]}),
        Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_3"]}),
    ]
    multi_model = MultiModel(models)
    rag = LightRAG(
        llm_model_func=multi_model.llm_model_func,
        # ..other args
    )
    ```
    """

    def __init__(self, models: list[Model]):
        if not models:
            # Fail fast with a clear message; otherwise the first request
            # would die with a modulo-by-zero inside _next_model.
            raise ValueError("MultiModel requires at least one model")
        self._models = models
        self._current_model = 0

    def _next_model(self) -> Model:
        """Advance the rotation by one and return the newly selected model.

        The index is incremented before the lookup, so the very first call
        returns the model at index 1 (or index 0 when only one model is
        configured).
        """
        self._current_model = (self._current_model + 1) % len(self._models)
        return self._models[self._current_model]

    async def llm_model_func(
        self,
        prompt: str,
        system_prompt: str | None = None,
        history_messages: list[dict[str, Any]] | None = None,
        **kwargs: Any,
    ) -> str:
        """Send one request to the next model in the rotation.

        Args:
            prompt: Prompt forwarded to the generation function.
            system_prompt: Optional system prompt, forwarded unchanged.
            history_messages: Earlier messages, forwarded unchanged. A fresh
                empty list is used when omitted.
            **kwargs: Additional keyword arguments for the generation
                function. ``model``, ``keyword_extraction`` and ``mode`` are
                discarded; on any other clash the selected model's own
                ``kwargs`` take precedence.

        Returns:
            Whatever the selected model's ``gen_func`` resolves to.
        """
        if history_messages is None:
            # A literal [] default would be one list shared by every call.
            history_messages = []

        # "model" is dropped so it cannot overwrite the custom model name.
        # "keyword_extraction" and "mode" are dropped too - presumably
        # caller-side options the generation functions do not accept
        # (TODO confirm).
        for dropped in ("model", "keyword_extraction", "mode"):
            kwargs.pop(dropped, None)

        next_model = self._next_model()

        # Merge instead of double-unpacking into dict(): a key present on
        # both sides used to raise TypeError; the model's value now wins.
        call_kwargs = {**kwargs, **next_model.kwargs}

        return await next_model.gen_func(
            prompt=prompt,
            system_prompt=system_prompt,
            history_messages=history_messages,
            **call_kwargs,
        )
if __name__ == "__main__":
    import asyncio

    async def main():
        """Ask the GPT-4o-mini helper a fixed question and print the reply."""
        # Imported lazily so that importing this module does not pull in the
        # OpenAI backend.
        from lightrag.llm.openai import gpt_4o_mini_complete

        print(await gpt_4o_mini_complete("How are you?"))

    asyncio.run(main())