feat: add index batch size setting for lightrag (#720)

This commit is contained in:
Tuan Anh Nguyen Dang (Tadashi_Cin) 2025-04-01 11:31:06 +07:00 committed by GitHub
parent 79a5f064a2
commit 2ffe374c2f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 61 additions and 25 deletions

View File

@ -86,7 +86,7 @@ RUN --mount=type=ssh \
ENV USE_LIGHTRAG=true
RUN --mount=type=ssh \
--mount=type=cache,target=/root/.cache/pip \
pip install aioboto3 nano-vectordb ollama xxhash "lightrag-hku<=0.0.8"
pip install aioboto3 nano-vectordb ollama xxhash "lightrag-hku<=1.3.0"
RUN --mount=type=ssh \
--mount=type=cache,target=/root/.cache/pip \

View File

@ -52,6 +52,10 @@ class LightRAGIndex(GraphRAGIndex):
pipeline.prompts = striped_settings
# set collection graph id
pipeline.collection_graph_id = self._get_or_create_collection_graph_id()
# set index batch size
pipeline.index_batch_size = striped_settings.get(
"batch_size", pipeline.index_batch_size
)
return pipeline
def get_retriever_pipelines(

View File

@ -243,6 +243,7 @@ class LightRAGIndexingPipeline(GraphRAGIndexingPipeline):
prompts: dict[str, str] = {}
collection_graph_id: str
index_batch_size: int = INDEX_BATCHSIZE
def store_file_id_with_graph_id(self, file_ids: list[str | None]):
if not settings.USE_GLOBAL_GRAPHRAG:
@ -283,18 +284,31 @@ class LightRAGIndexingPipeline(GraphRAGIndexingPipeline):
from lightrag.prompt import PROMPTS
blacklist_keywords = ["default", "response", "process"]
return {
prompt_name: {
"name": f"Prompt for '{prompt_name}'",
"value": content,
"component": "text",
settings_dict = {
"batch_size": {
"name": (
"Index batch size " "(reduce if you have rate limit issues)"
),
"value": INDEX_BATCHSIZE,
"component": "number",
}
for prompt_name, content in PROMPTS.items()
if all(
keyword not in prompt_name.lower() for keyword in blacklist_keywords
)
and isinstance(content, str)
}
settings_dict.update(
{
prompt_name: {
"name": f"Prompt for '{prompt_name}'",
"value": content,
"component": "text",
}
for prompt_name, content in PROMPTS.items()
if all(
keyword not in prompt_name.lower()
for keyword in blacklist_keywords
)
and isinstance(content, str)
}
)
return settings_dict
except ImportError as e:
print(e)
return {}
@ -359,8 +373,8 @@ class LightRAGIndexingPipeline(GraphRAGIndexingPipeline):
),
)
for doc_id in range(0, len(all_docs), INDEX_BATCHSIZE):
cur_docs = all_docs[doc_id : doc_id + INDEX_BATCHSIZE]
for doc_id in range(0, len(all_docs), self.index_batch_size):
cur_docs = all_docs[doc_id : doc_id + self.index_batch_size]
combined_doc = "\n".join(cur_docs)
# Use insert for incremental updates

View File

@ -52,6 +52,10 @@ class NanoGraphRAGIndex(GraphRAGIndex):
pipeline.prompts = striped_settings
# set collection graph id
pipeline.collection_graph_id = self._get_or_create_collection_graph_id()
# set index batch size
pipeline.index_batch_size = striped_settings.get(
"batch_size", pipeline.index_batch_size
)
return pipeline
def get_retriever_pipelines(

View File

@ -239,6 +239,7 @@ class NanoGraphRAGIndexingPipeline(GraphRAGIndexingPipeline):
prompts: dict[str, str] = {}
collection_graph_id: str
index_batch_size: int = INDEX_BATCHSIZE
def store_file_id_with_graph_id(self, file_ids: list[str | None]):
if not settings.USE_GLOBAL_GRAPHRAG:
@ -279,18 +280,31 @@ class NanoGraphRAGIndexingPipeline(GraphRAGIndexingPipeline):
from nano_graphrag.prompt import PROMPTS
blacklist_keywords = ["default", "response", "process"]
return {
prompt_name: {
"name": f"Prompt for '{prompt_name}'",
"value": content,
"component": "text",
settings_dict = {
"batch_size": {
"name": (
"Index batch size " "(reduce if you have rate limit issues)"
),
"value": INDEX_BATCHSIZE,
"component": "number",
}
for prompt_name, content in PROMPTS.items()
if all(
keyword not in prompt_name.lower() for keyword in blacklist_keywords
)
and isinstance(content, str)
}
settings_dict.update(
{
prompt_name: {
"name": f"Prompt for '{prompt_name}'",
"value": content,
"component": "text",
}
for prompt_name, content in PROMPTS.items()
if all(
keyword not in prompt_name.lower()
for keyword in blacklist_keywords
)
and isinstance(content, str)
}
)
return settings_dict
except ImportError as e:
print(e)
return {}
@ -355,8 +369,8 @@ class NanoGraphRAGIndexingPipeline(GraphRAGIndexingPipeline):
),
)
for doc_id in range(0, len(all_docs), INDEX_BATCHSIZE):
cur_docs = all_docs[doc_id : doc_id + INDEX_BATCHSIZE]
for doc_id in range(0, len(all_docs), self.index_batch_size):
cur_docs = all_docs[doc_id : doc_id + self.index_batch_size]
combined_doc = "\n".join(cur_docs)
# Use insert for incremental updates