mirror of
https://github.com/HKUDS/LightRAG.git
synced 2025-11-26 06:56:51 +00:00
Merge branch 'main' into context-builder
This commit is contained in:
commit
f57ed21593
@ -123,7 +123,7 @@ LLM_BINDING_API_KEY=your_api_key
|
|||||||
####################################################################################
|
####################################################################################
|
||||||
### Embedding Configuration (Should not be changed after the first file processed)
|
### Embedding Configuration (Should not be changed after the first file processed)
|
||||||
####################################################################################
|
####################################################################################
|
||||||
### Embedding Binding type: openai, ollama, lollms, azure_openai
|
### Embedding Binding type: openai, ollama, lollms, azure_openai, jina
|
||||||
EMBEDDING_BINDING=ollama
|
EMBEDDING_BINDING=ollama
|
||||||
EMBEDDING_MODEL=bge-m3:latest
|
EMBEDDING_MODEL=bge-m3:latest
|
||||||
EMBEDDING_DIM=1024
|
EMBEDDING_DIM=1024
|
||||||
@ -139,6 +139,13 @@ EMBEDDING_BINDING_HOST=http://localhost:11434
|
|||||||
# AZURE_EMBEDDING_ENDPOINT=your_endpoint
|
# AZURE_EMBEDDING_ENDPOINT=your_endpoint
|
||||||
# AZURE_EMBEDDING_API_KEY=your_api_key
|
# AZURE_EMBEDDING_API_KEY=your_api_key
|
||||||
|
|
||||||
|
### Jina AI Embedding
|
||||||
|
EMBEDDING_BINDING=jina
|
||||||
|
EMBEDDING_BINDING_HOST=https://api.jina.ai/v1/embeddings
|
||||||
|
EMBEDDING_MODEL=jina-embeddings-v4
|
||||||
|
EMBEDDING_DIM=2048
|
||||||
|
EMBEDDING_BINDING_API_KEY=your_api_key
|
||||||
|
|
||||||
############################
|
############################
|
||||||
### Data storage selection
|
### Data storage selection
|
||||||
############################
|
############################
|
||||||
|
|||||||
@ -89,7 +89,13 @@ def create_app(args):
|
|||||||
]:
|
]:
|
||||||
raise Exception("llm binding not supported")
|
raise Exception("llm binding not supported")
|
||||||
|
|
||||||
if args.embedding_binding not in ["lollms", "ollama", "openai", "azure_openai"]:
|
if args.embedding_binding not in [
|
||||||
|
"lollms",
|
||||||
|
"ollama",
|
||||||
|
"openai",
|
||||||
|
"azure_openai",
|
||||||
|
"jina",
|
||||||
|
]:
|
||||||
raise Exception("embedding binding not supported")
|
raise Exception("embedding binding not supported")
|
||||||
|
|
||||||
# Set default hosts if not provided
|
# Set default hosts if not provided
|
||||||
@ -213,6 +219,8 @@ def create_app(args):
|
|||||||
if args.llm_binding_host == "openai-ollama" or args.embedding_binding == "ollama":
|
if args.llm_binding_host == "openai-ollama" or args.embedding_binding == "ollama":
|
||||||
from lightrag.llm.openai import openai_complete_if_cache
|
from lightrag.llm.openai import openai_complete_if_cache
|
||||||
from lightrag.llm.ollama import ollama_embed
|
from lightrag.llm.ollama import ollama_embed
|
||||||
|
if args.embedding_binding == "jina":
|
||||||
|
from lightrag.llm.jina import jina_embed
|
||||||
|
|
||||||
async def openai_alike_model_complete(
|
async def openai_alike_model_complete(
|
||||||
prompt,
|
prompt,
|
||||||
@ -284,6 +292,13 @@ def create_app(args):
|
|||||||
api_key=args.embedding_binding_api_key,
|
api_key=args.embedding_binding_api_key,
|
||||||
)
|
)
|
||||||
if args.embedding_binding == "azure_openai"
|
if args.embedding_binding == "azure_openai"
|
||||||
|
else jina_embed(
|
||||||
|
texts,
|
||||||
|
dimensions=args.embedding_dim,
|
||||||
|
base_url=args.embedding_binding_host,
|
||||||
|
api_key=args.embedding_binding_api_key,
|
||||||
|
)
|
||||||
|
if args.embedding_binding == "jina"
|
||||||
else openai_embed(
|
else openai_embed(
|
||||||
texts,
|
texts,
|
||||||
model=args.embedding_model,
|
model=args.embedding_model,
|
||||||
|
|||||||
@ -80,7 +80,7 @@ class PostgreSQLDB:
|
|||||||
if ssl_mode in ["disable", "allow", "prefer", "require"]:
|
if ssl_mode in ["disable", "allow", "prefer", "require"]:
|
||||||
if ssl_mode == "disable":
|
if ssl_mode == "disable":
|
||||||
return None
|
return None
|
||||||
elif ssl_mode in ["require", "prefer"]:
|
elif ssl_mode in ["require", "prefer", "allow"]:
|
||||||
# Return None for simple SSL requirement, handled in initdb
|
# Return None for simple SSL requirement, handled in initdb
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@ -2,45 +2,117 @@ import os
|
|||||||
import pipmaster as pm # Pipmaster for dynamic library install
|
import pipmaster as pm # Pipmaster for dynamic library install
|
||||||
|
|
||||||
# install specific modules
|
# install specific modules
|
||||||
if not pm.is_installed("lmdeploy"):
|
if not pm.is_installed("aiohttp"):
|
||||||
pm.install("lmdeploy")
|
pm.install("aiohttp")
|
||||||
if not pm.is_installed("tenacity"):
|
if not pm.is_installed("tenacity"):
|
||||||
pm.install("tenacity")
|
pm.install("tenacity")
|
||||||
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
from tenacity import (
|
||||||
|
retry,
|
||||||
|
stop_after_attempt,
|
||||||
|
wait_exponential,
|
||||||
|
retry_if_exception_type,
|
||||||
|
)
|
||||||
|
from lightrag.utils import wrap_embedding_func_with_attrs, logger
|
||||||
|
|
||||||
|
|
||||||
async def fetch_data(url, headers, data):
|
async def fetch_data(url, headers, data):
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
async with session.post(url, headers=headers, json=data) as response:
|
async with session.post(url, headers=headers, json=data) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
error_text = await response.text()
|
||||||
|
logger.error(f"Jina API error {response.status}: {error_text}")
|
||||||
|
raise aiohttp.ClientResponseError(
|
||||||
|
request_info=response.request_info,
|
||||||
|
history=response.history,
|
||||||
|
status=response.status,
|
||||||
|
message=f"Jina API error: {error_text}",
|
||||||
|
)
|
||||||
response_json = await response.json()
|
response_json = await response.json()
|
||||||
data_list = response_json.get("data", [])
|
data_list = response_json.get("data", [])
|
||||||
return data_list
|
return data_list
|
||||||
|
|
||||||
|
|
||||||
|
@wrap_embedding_func_with_attrs(embedding_dim=2048, max_token_size=8192)
|
||||||
|
@retry(
|
||||||
|
stop=stop_after_attempt(3),
|
||||||
|
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||||
|
retry=(
|
||||||
|
retry_if_exception_type(aiohttp.ClientError)
|
||||||
|
| retry_if_exception_type(aiohttp.ClientResponseError)
|
||||||
|
),
|
||||||
|
)
|
||||||
async def jina_embed(
|
async def jina_embed(
|
||||||
texts: list[str],
|
texts: list[str],
|
||||||
dimensions: int = 1024,
|
dimensions: int = 2048,
|
||||||
late_chunking: bool = False,
|
late_chunking: bool = False,
|
||||||
base_url: str = None,
|
base_url: str = None,
|
||||||
api_key: str = None,
|
api_key: str = None,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
|
"""Generate embeddings for a list of texts using Jina AI's API.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts: List of texts to embed.
|
||||||
|
dimensions: The embedding dimensions (default: 2048 for jina-embeddings-v4).
|
||||||
|
late_chunking: Whether to use late chunking.
|
||||||
|
base_url: Optional base URL for the Jina API.
|
||||||
|
api_key: Optional Jina API key. If None, uses the JINA_API_KEY environment variable.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A numpy array of embeddings, one per input text.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
aiohttp.ClientError: If there is a connection error with the Jina API.
|
||||||
|
aiohttp.ClientResponseError: If the Jina API returns an error response.
|
||||||
|
"""
|
||||||
if api_key:
|
if api_key:
|
||||||
os.environ["JINA_API_KEY"] = api_key
|
os.environ["JINA_API_KEY"] = api_key
|
||||||
url = "https://api.jina.ai/v1/embeddings" if not base_url else base_url
|
|
||||||
|
if "JINA_API_KEY" not in os.environ:
|
||||||
|
raise ValueError("JINA_API_KEY environment variable is required")
|
||||||
|
|
||||||
|
url = base_url or "https://api.jina.ai/v1/embeddings"
|
||||||
headers = {
|
headers = {
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
"Authorization": f"Bearer {os.environ['JINA_API_KEY']}",
|
"Authorization": f"Bearer {os.environ['JINA_API_KEY']}",
|
||||||
}
|
}
|
||||||
data = {
|
data = {
|
||||||
"model": "jina-embeddings-v3",
|
"model": "jina-embeddings-v4",
|
||||||
"normalized": True,
|
"task": "text-matching",
|
||||||
"embedding_type": "float",
|
"dimensions": dimensions,
|
||||||
"dimensions": f"{dimensions}",
|
|
||||||
"late_chunking": late_chunking,
|
|
||||||
"input": texts,
|
"input": texts,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Only add optional parameters if they have non-default values
|
||||||
|
if late_chunking:
|
||||||
|
data["late_chunking"] = late_chunking
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"Jina embedding request: {len(texts)} texts, dimensions: {dimensions}"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
data_list = await fetch_data(url, headers, data)
|
data_list = await fetch_data(url, headers, data)
|
||||||
return np.array([dp["embedding"] for dp in data_list])
|
|
||||||
|
if not data_list:
|
||||||
|
logger.error("Jina API returned empty data list")
|
||||||
|
raise ValueError("Jina API returned empty data list")
|
||||||
|
|
||||||
|
if len(data_list) != len(texts):
|
||||||
|
logger.error(
|
||||||
|
f"Jina API returned {len(data_list)} embeddings for {len(texts)} texts"
|
||||||
|
)
|
||||||
|
raise ValueError(
|
||||||
|
f"Jina API returned {len(data_list)} embeddings for {len(texts)} texts"
|
||||||
|
)
|
||||||
|
|
||||||
|
embeddings = np.array([dp["embedding"] for dp in data_list])
|
||||||
|
logger.debug(f"Jina embeddings generated: shape {embeddings.shape}")
|
||||||
|
|
||||||
|
return embeddings
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Jina embedding error: {e}")
|
||||||
|
raise
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user