diff --git a/README-zh.md b/README-zh.md
index 45335489..e9599099 100644
--- a/README-zh.md
+++ b/README-zh.md
@@ -30,7 +30,7 @@
-
+
diff --git a/env.example b/env.example
index df88a518..8002f00c 100644
--- a/env.example
+++ b/env.example
@@ -96,7 +96,7 @@ EMBEDDING_BINDING_API_KEY=your_api_key
# If the embedding service is deployed within the same Docker stack, use host.docker.internal instead of localhost
EMBEDDING_BINDING_HOST=http://localhost:11434
### Num of chunks send to Embedding in single request
-# EMBEDDING_BATCH_NUM=32
+# EMBEDDING_BATCH_NUM=10
### Max concurrency requests for Embedding
# EMBEDDING_FUNC_MAX_ASYNC=16
### Maximum tokens sent to Embedding for each chunk (no longer in use?)
diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index cbb5e2a8..1f61a42e 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -201,7 +201,7 @@ class LightRAG:
embedding_func: EmbeddingFunc | None = field(default=None)
"""Function for computing text embeddings. Must be set before use."""
- embedding_batch_num: int = field(default=int(os.getenv("EMBEDDING_BATCH_NUM", 32)))
+ embedding_batch_num: int = field(default=int(os.getenv("EMBEDDING_BATCH_NUM", 10)))
"""Batch size for embedding computations."""
embedding_func_max_async: int = field(
diff --git a/lightrag/llm/openai.py b/lightrag/llm/openai.py
index 57f016cf..eb74c2f1 100644
--- a/lightrag/llm/openai.py
+++ b/lightrag/llm/openai.py
@@ -210,9 +210,18 @@ async def openai_complete_if_cache(
async def inner():
# Track if we've started iterating
iteration_started = False
+ final_chunk_usage = None
+
try:
iteration_started = True
async for chunk in response:
+ # Check if this chunk has usage information (final chunk)
+ if hasattr(chunk, "usage") and chunk.usage:
+ final_chunk_usage = chunk.usage
+ logger.debug(
+ f"Received usage info in streaming chunk: {chunk.usage}"
+ )
+
# Check if choices exists and is not empty
if not hasattr(chunk, "choices") or not chunk.choices:
logger.warning(f"Received chunk without choices: {chunk}")
@@ -222,16 +231,31 @@ async def openai_complete_if_cache(
if not hasattr(chunk.choices[0], "delta") or not hasattr(
chunk.choices[0].delta, "content"
):
- logger.warning(
- f"Received chunk without delta content: {chunk.choices[0]}"
- )
+ # This might be the final chunk, continue to check for usage
continue
+
content = chunk.choices[0].delta.content
if content is None:
continue
if r"\u" in content:
content = safe_unicode_decode(content.encode("utf-8"))
+
yield content
+
+ # After streaming is complete, track token usage
+ if token_tracker and final_chunk_usage:
+ # Use actual usage from the API
+ token_counts = {
+ "prompt_tokens": getattr(final_chunk_usage, "prompt_tokens", 0),
+ "completion_tokens": getattr(
+ final_chunk_usage, "completion_tokens", 0
+ ),
+ "total_tokens": getattr(final_chunk_usage, "total_tokens", 0),
+ }
+ token_tracker.add_usage(token_counts)
+ logger.debug(f"Streaming token usage (from API): {token_counts}")
+ elif token_tracker:
+ logger.debug("No usage information available in streaming response")
except Exception as e:
logger.error(f"Error in stream response: {str(e)}")
# Try to clean up resources if possible
diff --git a/lightrag/operate.py b/lightrag/operate.py
index 88837435..4e219cf8 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -26,6 +26,7 @@ from .utils import (
get_conversation_turns,
use_llm_func_with_cache,
update_chunk_cache_list,
+ remove_think_tags,
)
from .base import (
BaseGraphStorage,
@@ -1703,7 +1704,8 @@ async def extract_keywords_only(
result = await use_model_func(kw_prompt, keyword_extraction=True)
# 6. Parse out JSON from the LLM response
- match = re.search(r"\{.*\}", result, re.DOTALL)
+ result = remove_think_tags(result)
+ match = re.search(r"\{.*?\}", result, re.DOTALL)
if not match:
logger.error("No JSON-like structure found in the LLM respond.")
return [], []
diff --git a/lightrag/utils.py b/lightrag/utils.py
index c6e2def9..386de3ab 100644
--- a/lightrag/utils.py
+++ b/lightrag/utils.py
@@ -1465,6 +1465,11 @@ async def update_chunk_cache_list(
)
+def remove_think_tags(text: str) -> str:
+ """Remove tags from the text"""
+ return re.sub(r"^(.*?|)", "", text, flags=re.DOTALL).strip()
+
+
async def use_llm_func_with_cache(
input_text: str,
use_llm_func: callable,
@@ -1531,6 +1536,7 @@ async def use_llm_func_with_cache(
kwargs["max_tokens"] = max_tokens
res: str = await use_llm_func(input_text, **kwargs)
+ res = remove_think_tags(res)
if llm_response_cache.global_config.get("enable_llm_cache_for_entity_extract"):
await save_to_cache(
@@ -1557,8 +1563,9 @@ async def use_llm_func_with_cache(
if max_tokens is not None:
kwargs["max_tokens"] = max_tokens
- logger.info(f"Call LLM function with query text lenght: {len(input_text)}")
- return await use_llm_func(input_text, **kwargs)
+ logger.info(f"Call LLM function with query text length: {len(input_text)}")
+ res = await use_llm_func(input_text, **kwargs)
+ return remove_think_tags(res)
def get_content_summary(content: str, max_length: int = 250) -> str: