LightRAG/lightrag/llm/ollama.py


import sys

if sys.version_info < (3, 9):
    from typing import AsyncIterator
else:
    from collections.abc import AsyncIterator

import pipmaster as pm  # Pipmaster for dynamic library install

# install specific modules
if not pm.is_installed("ollama"):
    pm.install("ollama")
if not pm.is_installed("tenacity"):
    pm.install("tenacity")

import ollama

from tenacity import (
    retry,
    stop_after_attempt,
    wait_exponential,
    retry_if_exception_type,
)
from lightrag.exceptions import (
    APIConnectionError,
    RateLimitError,
    APITimeoutError,
)
from lightrag.api import __api_version__

import numpy as np
from typing import Union


@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type(
        (RateLimitError, APIConnectionError, APITimeoutError)
    ),
)
async def _ollama_model_if_cache(
    model,
    prompt,
    system_prompt=None,
    history_messages=[],
    **kwargs,
) -> Union[str, AsyncIterator[str]]:
    """Low-level chat call against an Ollama server.

    Returns the full response text, or an async iterator of content chunks
    when ``stream=True`` is passed through ``kwargs``.
    """
    stream = True if kwargs.get("stream") else False
    kwargs.pop("max_tokens", None)
    # kwargs.pop("response_format", None) # allow json
    host = kwargs.pop("host", None)
    timeout = kwargs.pop("timeout", None)
    kwargs.pop("hashing_kv", None)
    api_key = kwargs.pop("api_key", None)
    headers = {
        "Content-Type": "application/json",
        "User-Agent": f"LightRAG/{__api_version__}",
    }
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"

    ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers)
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})

    response = await ollama_client.chat(model=model, messages=messages, **kwargs)
    if stream:
"""cannot cache stream response and process reasoning"""
async def inner():
async for chunk in response:
yield chunk["message"]["content"]
return inner()
else:
        model_response = response["message"]["content"]
        # If the model wraps its reasoning in a specific tag, that content is
        # not needed in the final response and can simply be trimmed here.
        return model_response


async def ollama_model_complete(
    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> Union[str, AsyncIterator[str]]:
    """LightRAG-facing wrapper: resolves the model name from the injected
    ``hashing_kv`` config and delegates to ``_ollama_model_if_cache``."""
    # When keyword extraction is requested, ask the model for JSON output
    if keyword_extraction:
        kwargs["format"] = "json"
    model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
    return await _ollama_model_if_cache(
        model_name,
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        **kwargs,
    )
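

# Usage sketch (illustrative, not part of the original module): calling
# ollama_model_complete outside the LightRAG pipeline. The framework normally
# injects `hashing_kv`; a SimpleNamespace carrying a `global_config` dict is
# enough for the model-name lookup above. Model tag and host are placeholders,
# and a local Ollama server must be running.
#
#     import asyncio
#     from types import SimpleNamespace
#
#     fake_kv = SimpleNamespace(global_config={"llm_model_name": "qwen2.5:7b"})
#     answer = asyncio.run(
#         ollama_model_complete(
#             "Summarize LightRAG in one sentence.",
#             hashing_kv=fake_kv,
#             host="http://localhost:11434",
#         )
#     )
#     print(answer)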


async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarray:
    """Deprecated in favor of `ollama_embed`; embeds texts one request at a time."""
    embed_text = []
    ollama_client = ollama.Client(**kwargs)
    for text in texts:
        data = ollama_client.embeddings(model=embed_model, prompt=text)
        embed_text.append(data["embedding"])
    return embed_text
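
# Migration sketch (illustrative): the deprecated per-text loop above is
# equivalent to one batched call to `ollama_embed` defined below, e.g.
#
#     vectors = await ollama_embed(
#         ["first document", "second document"],
#         embed_model="nomic-embed-text",  # placeholder model tag
#     )
#     # vectors.shape == (2, embedding_dim)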


async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray:
    """Embed a batch of texts with an Ollama embedding model.

    Returns a numpy array of shape (len(texts), embedding_dim).
    """
    api_key = kwargs.pop("api_key", None)
    headers = {
        "Content-Type": "application/json",
        "User-Agent": f"LightRAG/{__api_version__}",
    }
    if api_key:
        headers["Authorization"] = api_key
    kwargs["headers"] = headers
    ollama_client = ollama.Client(**kwargs)
    data = ollama_client.embed(model=embed_model, input=texts)
    return np.array(data["embeddings"])
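

# Wiring sketch (illustrative, not part of the original module): one way these
# helpers plug into a LightRAG instance. The constructor arguments and the
# EmbeddingFunc wrapper follow LightRAG's example scripts but are assumptions
# here; model tags, embedding_dim, and host are placeholders for a local
# Ollama deployment.
if __name__ == "__main__":
    from lightrag import LightRAG
    from lightrag.utils import EmbeddingFunc

    rag = LightRAG(
        working_dir="./rag_storage",
        llm_model_func=ollama_model_complete,
        llm_model_name="qwen2.5:7b",  # placeholder model tag
        llm_model_kwargs={"host": "http://localhost:11434"},
        embedding_func=EmbeddingFunc(
            embedding_dim=768,  # must match the chosen embedding model
            max_token_size=8192,
            func=lambda texts: ollama_embed(
                texts, embed_model="nomic-embed-text", host="http://localhost:11434"
            ),
        ),
    )
    # rag.insert(...) / rag.query(...) would then run against the local
    # Ollama server; they are omitted here to keep the sketch minimal.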