mirror of
https://github.com/HKUDS/LightRAG.git
synced 2025-11-22 21:15:52 +00:00
This parameter is no longer used. Its removal simplifies the API and clarifies that token length management is handled by upstream text chunking logic rather than the embedding wrapper.
119 lines
3.8 KiB
Python
119 lines
3.8 KiB
Python
import os
|
|
import pipmaster as pm # Pipmaster for dynamic library install
|
|
|
|
# Install the runtime HTTP/retry dependencies on first import if they are
# missing, so the imports below succeed.
for _required_pkg in ("aiohttp", "tenacity"):
    if not pm.is_installed(_required_pkg):
        pm.install(_required_pkg)
# Drop the loop variable so it does not linger in the module namespace.
del _required_pkg
|
|
|
|
import numpy as np
|
|
import aiohttp
|
|
from tenacity import (
|
|
retry,
|
|
stop_after_attempt,
|
|
wait_exponential,
|
|
retry_if_exception_type,
|
|
)
|
|
from lightrag.utils import wrap_embedding_func_with_attrs, logger
|
|
|
|
|
|
async def fetch_data(url, headers, data):
    """POST ``data`` as JSON to ``url`` and return the reply's ``"data"`` list.

    A fresh ``aiohttp.ClientSession`` is opened for each call.

    Args:
        url: Endpoint to POST to.
        headers: HTTP headers to send with the request.
        data: JSON-serializable request body.

    Returns:
        The value of the ``"data"`` key of the JSON response, or ``[]`` when
        the key is absent.

    Raises:
        aiohttp.ClientResponseError: If the response status is not 200; the
            response body text is logged and included in the message.
    """
    async with aiohttp.ClientSession() as http_session:
        async with http_session.post(url, headers=headers, json=data) as resp:
            if resp.status == 200:
                payload = await resp.json()
                return payload.get("data", [])

            # Non-200: surface the server's explanation in both log and error.
            body = await resp.text()
            logger.error(f"Jina API error {resp.status}: {body}")
            raise aiohttp.ClientResponseError(
                request_info=resp.request_info,
                history=resp.history,
                status=resp.status,
                message=f"Jina API error: {body}",
            )
|
|
|
|
|
|
@wrap_embedding_func_with_attrs(embedding_dim=2048)
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=60),
    # aiohttp.ClientResponseError is a subclass of aiohttp.ClientError, so this
    # single predicate covers both connection failures and non-200 responses.
    retry=retry_if_exception_type(aiohttp.ClientError),
)
async def jina_embed(
    texts: list[str],
    dimensions: int = 2048,
    late_chunking: bool = False,
    base_url: str = None,
    api_key: str = None,
) -> np.ndarray:
    """Generate embeddings for a list of texts using Jina AI's API.

    Args:
        texts: List of texts to embed. An empty list returns an empty
            ``(0, dimensions)`` array without contacting the API.
        dimensions: The embedding dimensions (default: 2048 for jina-embeddings-v4).
        late_chunking: Whether to use late chunking. The flag is only sent to
            the API when it is True.
        base_url: Optional endpoint URL. Defaults to
            ``https://api.jina.ai/v1/embeddings``.
        api_key: Optional Jina API key. If None or empty, the ``JINA_API_KEY``
            environment variable is used. The process environment is never
            modified by this function.

    Returns:
        A numpy array of embeddings, one row per input text, in the order the
        API returned them.

    Raises:
        ValueError: If no API key is available, or if the API returns an empty
            data list or a number of embeddings different from ``len(texts)``.
            These errors are not retried.
        tenacity.RetryError: If all 3 attempts fail with ``aiohttp.ClientError``
            (connection errors, or ``aiohttp.ClientResponseError`` for non-200
            responses). The last underlying error is available via
            ``.last_attempt``.
    """
    # Resolve the key locally rather than writing it into os.environ, so a
    # per-call key does not leak into process-wide state or later calls.
    resolved_key = api_key or os.environ.get("JINA_API_KEY")
    if not resolved_key:
        raise ValueError("JINA_API_KEY environment variable is required")

    # Nothing to embed: skip the network round-trip entirely.
    if not texts:
        return np.empty((0, dimensions))

    url = base_url or "https://api.jina.ai/v1/embeddings"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {resolved_key}",
    }
    data = {
        "model": "jina-embeddings-v4",
        "task": "text-matching",
        "dimensions": dimensions,
        "input": texts,
    }

    # Only add optional parameters if they have non-default values
    if late_chunking:
        data["late_chunking"] = late_chunking

    logger.debug(
        f"Jina embedding request: {len(texts)} texts, dimensions: {dimensions}"
    )

    try:
        data_list = await fetch_data(url, headers, data)

        if not data_list:
            logger.error("Jina API returned empty data list")
            raise ValueError("Jina API returned empty data list")

        # Guard against a partial response so callers never receive a
        # misaligned text-to-embedding mapping.
        if len(data_list) != len(texts):
            logger.error(
                f"Jina API returned {len(data_list)} embeddings for {len(texts)} texts"
            )
            raise ValueError(
                f"Jina API returned {len(data_list)} embeddings for {len(texts)} texts"
            )

        embeddings = np.array([dp["embedding"] for dp in data_list])
        logger.debug(f"Jina embeddings generated: shape {embeddings.shape}")

        return embeddings

    except Exception as e:
        logger.error(f"Jina embedding error: {e}")
        raise
|