LightRAG/lightrag/llm/lollms.py
2025-05-14 11:30:48 +08:00

160 lines
4.8 KiB
Python

import sys
if sys.version_info < (3, 9):
from typing import AsyncIterator
else:
from collections.abc import AsyncIterator
import pipmaster as pm # Pipmaster for dynamic library install
if not pm.is_installed("aiohttp"):
pm.install("aiohttp")
import aiohttp
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
from lightrag.exceptions import (
APIConnectionError,
RateLimitError,
APITimeoutError,
)
from typing import Union, List
import numpy as np
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type(
        (RateLimitError, APIConnectionError, APITimeoutError)
    ),
)
async def lollms_model_if_cache(
    model,
    prompt,
    system_prompt=None,
    history_messages=[],
    base_url="http://localhost:9600",
    **kwargs,
) -> Union[str, AsyncIterator[str]]:
    """Send a generation request to a lollms server.

    The system prompt, the history and the user prompt are flattened into a
    single text prompt and POSTed as JSON to ``{base_url}/lollms_generate``.

    Args:
        model: Value sent as ``model_name`` in the request body.
        prompt: The user prompt; appended last to the flattened prompt.
        system_prompt: Optional text placed first, followed by a newline.
        history_messages: Sequence of dicts with ``"role"`` and ``"content"``
            keys, each rendered as ``"<role>: <content>\\n"``. ``None`` is
            treated as an empty history. The default list is never mutated.
        base_url: Root URL of the lollms server.
        **kwargs: Optional settings read by this function:
            ``stream`` (truthy enables streaming), ``api_key`` (sent as a
            ``Bearer`` token; removed from ``kwargs``), ``timeout`` (total
            timeout handed to ``aiohttp.ClientTimeout``), and the generation
            parameters ``personality`` (-1), ``n_predict`` (None),
            ``temperature`` (0.1), ``top_k`` (50), ``top_p`` (0.95),
            ``repeat_penalty`` (0.8), ``repeat_last_n`` (40), ``seed`` (None)
            and ``n_threads`` (8); defaults in parentheses. Any other key is
            ignored.

    Returns:
        When streaming, an async iterator that yields each response line
        decoded and stripped; otherwise the full response body as text.
    """
    stream = bool(kwargs.get("stream"))
    api_key = kwargs.pop("api_key", None)

    headers = {"Content-Type": "application/json"}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"

    # Flatten system prompt + history + user prompt into one text prompt.
    prompt_parts = []
    if system_prompt:
        prompt_parts.append(f"{system_prompt}\n")
    prompt_parts.extend(
        f"{msg['role']}: {msg['content']}\n" for msg in history_messages or []
    )
    prompt_parts.append(prompt)

    # Extract lollms specific parameters
    request_data = {
        "prompt": "".join(prompt_parts),
        "model_name": model,
        "personality": kwargs.get("personality", -1),
        "n_predict": kwargs.get("n_predict", None),
        "stream": stream,
        "temperature": kwargs.get("temperature", 0.1),
        "top_k": kwargs.get("top_k", 50),
        "top_p": kwargs.get("top_p", 0.95),
        "repeat_penalty": kwargs.get("repeat_penalty", 0.8),
        "repeat_last_n": kwargs.get("repeat_last_n", 40),
        "seed": kwargs.get("seed", None),
        "n_threads": kwargs.get("n_threads", 8),
    }

    timeout = aiohttp.ClientTimeout(total=kwargs.get("timeout", None))
    url = f"{base_url}/lollms_generate"

    if stream:

        async def inner():
            # The generator body only runs once the caller starts iterating,
            # i.e. after this function has returned. The session therefore
            # has to be opened here: a session opened by the enclosing
            # function would already be closed by the time the request is
            # made.
            async with aiohttp.ClientSession(
                timeout=timeout, headers=headers
            ) as session:
                async with session.post(url, json=request_data) as response:
                    async for line in response.content:
                        yield line.decode().strip()

        # NOTE(review): the retry decorator wraps only the creation of this
        # generator; failures while iterating it are not retried.
        return inner()

    # NOTE(review): nothing in this function raises RateLimitError,
    # APIConnectionError or APITimeoutError explicitly, and the HTTP status
    # is not checked -- confirm whether the retry policy can ever trigger.
    async with aiohttp.ClientSession(timeout=timeout, headers=headers) as session:
        async with session.post(url, json=request_data) as response:
            return await response.text()
async def lollms_model_complete(
    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> Union[str, AsyncIterator[str]]:
    """Run a lollms completion using the model name from the global config.

    Args:
        prompt: The user prompt.
        system_prompt: Optional system prompt, forwarded unchanged.
        history_messages: Conversation history, forwarded unchanged. The
            default list is never mutated here.
        keyword_extraction: Accepted for interface compatibility. It
            currently has no effect on the request.
        **kwargs: Must contain ``hashing_kv``, an object whose
            ``global_config["llm_model_name"]`` gives the model name. All
            kwargs (including ``hashing_kv``) are forwarded to
            ``lollms_model_if_cache``.

    Returns:
        Whatever ``lollms_model_if_cache`` returns: the response text, or an
        async iterator of lines when streaming was requested.

    Raises:
        KeyError: If ``hashing_kv`` is not supplied in ``kwargs``.
    """
    # ``keyword_extraction`` is a named parameter, so it can never appear in
    # ``kwargs``; popping it from there would only overwrite the caller's
    # value with None. It is deliberately left untouched and unused.
    # NOTE(review): if lollms gains structured/JSON output support, this is
    # the place to adjust the prompt or parameters when the flag is set.

    # Get model name from config
    model_name = kwargs["hashing_kv"].global_config["llm_model_name"]

    return await lollms_model_if_cache(
        model_name,
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        **kwargs,
    )
async def lollms_embed(
    texts: List[str], embed_model=None, base_url="http://localhost:9600", **kwargs
) -> np.ndarray:
    """
    Embed each text by calling the lollms server's embedding endpoint.

    One POST to ``{base_url}/lollms_embed`` is issued per text, one after the
    other, and the ``"vector"`` field of each JSON reply is collected.

    Args:
        texts: List of strings to embed
        embed_model: Accepted but not referenced by this function
        base_url: URL of the lollms server
        **kwargs: Only ``api_key`` is read (and removed); any other keyword
            is ignored

    Returns:
        np.ndarray: ``np.array`` built from the collected vectors, in the
        same order as ``texts``
    """
    token = kwargs.pop("api_key", None)

    request_headers = {"Content-Type": "application/json"}
    if token:
        # NOTE(review): the key is sent verbatim, without the "Bearer "
        # prefix used for generation requests in this file -- confirm which
        # form the server expects.
        request_headers["Authorization"] = token

    async with aiohttp.ClientSession(headers=request_headers) as session:

        async def fetch_vector(text):
            # One request per text; the reply is parsed as JSON and its
            # "vector" entry is returned.
            async with session.post(
                f"{base_url}/lollms_embed",
                json={"text": text},
            ) as response:
                payload = await response.json()
                return payload["vector"]

        # Awaiting inside the comprehension keeps the requests sequential.
        vectors = [await fetch_vector(text) for text in texts]
        return np.array(vectors)