diff --git a/README-zh.md b/README-zh.md index 45335489..e9599099 100644 --- a/README-zh.md +++ b/README-zh.md @@ -30,7 +30,7 @@

- +

diff --git a/env.example b/env.example index df88a518..8002f00c 100644 --- a/env.example +++ b/env.example @@ -96,7 +96,7 @@ EMBEDDING_BINDING_API_KEY=your_api_key # If the embedding service is deployed within the same Docker stack, use host.docker.internal instead of localhost EMBEDDING_BINDING_HOST=http://localhost:11434 ### Num of chunks send to Embedding in single request -# EMBEDDING_BATCH_NUM=32 +# EMBEDDING_BATCH_NUM=10 ### Max concurrency requests for Embedding # EMBEDDING_FUNC_MAX_ASYNC=16 ### Maximum tokens sent to Embedding for each chunk (no longer in use?) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index cbb5e2a8..1f61a42e 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -201,7 +201,7 @@ class LightRAG: embedding_func: EmbeddingFunc | None = field(default=None) """Function for computing text embeddings. Must be set before use.""" - embedding_batch_num: int = field(default=int(os.getenv("EMBEDDING_BATCH_NUM", 32))) + embedding_batch_num: int = field(default=int(os.getenv("EMBEDDING_BATCH_NUM", 10))) """Batch size for embedding computations.""" embedding_func_max_async: int = field( diff --git a/lightrag/llm/openai.py b/lightrag/llm/openai.py index 57f016cf..eb74c2f1 100644 --- a/lightrag/llm/openai.py +++ b/lightrag/llm/openai.py @@ -210,9 +210,18 @@ async def openai_complete_if_cache( async def inner(): # Track if we've started iterating iteration_started = False + final_chunk_usage = None + try: iteration_started = True async for chunk in response: + # Check if this chunk has usage information (final chunk) + if hasattr(chunk, "usage") and chunk.usage: + final_chunk_usage = chunk.usage + logger.debug( + f"Received usage info in streaming chunk: {chunk.usage}" + ) + # Check if choices exists and is not empty if not hasattr(chunk, "choices") or not chunk.choices: logger.warning(f"Received chunk without choices: {chunk}") @@ -222,16 +231,31 @@ async def openai_complete_if_cache( if not hasattr(chunk.choices[0], "delta") or not hasattr( chunk.choices[0].delta, "content" ): - logger.warning( - f"Received chunk without delta content: {chunk.choices[0]}" - ) + # This might be the final chunk, continue to check for usage continue + content = chunk.choices[0].delta.content if content is None: continue if r"\u" in content: content = safe_unicode_decode(content.encode("utf-8")) + yield content + + # After streaming is complete, track token usage + if token_tracker and final_chunk_usage: + # Use actual usage from the API + token_counts = { + "prompt_tokens": getattr(final_chunk_usage, "prompt_tokens", 0), + "completion_tokens": getattr( + final_chunk_usage, "completion_tokens", 0 + ), + "total_tokens": getattr(final_chunk_usage, "total_tokens", 0), + } + token_tracker.add_usage(token_counts) + logger.debug(f"Streaming token usage (from API): {token_counts}") + elif token_tracker: + logger.debug("No usage information available in streaming response") except Exception as e: logger.error(f"Error in stream response: {str(e)}") # Try to clean up resources if possible diff --git a/lightrag/operate.py b/lightrag/operate.py index 88837435..4e219cf8 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -26,6 +26,7 @@ from .utils import ( get_conversation_turns, use_llm_func_with_cache, update_chunk_cache_list, + remove_think_tags, ) from .base import ( BaseGraphStorage, @@ -1703,7 +1704,8 @@ async def extract_keywords_only( result = await use_model_func(kw_prompt, keyword_extraction=True) # 6. Parse out JSON from the LLM response - match = re.search(r"\{.*\}", result, re.DOTALL) + result = remove_think_tags(result) + match = re.search(r"\{.*?\}", result, re.DOTALL) if not match: logger.error("No JSON-like structure found in the LLM respond.") return [], [] diff --git a/lightrag/utils.py b/lightrag/utils.py index c6e2def9..386de3ab 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -1465,6 +1465,11 @@ async def update_chunk_cache_list( ) +def remove_think_tags(text: str) -> str: + """Remove tags from the text""" + return re.sub(r"^(.*?|)", "", text, flags=re.DOTALL).strip() + + async def use_llm_func_with_cache( input_text: str, use_llm_func: callable, @@ -1531,6 +1536,7 @@ async def use_llm_func_with_cache( kwargs["max_tokens"] = max_tokens res: str = await use_llm_func(input_text, **kwargs) + res = remove_think_tags(res) if llm_response_cache.global_config.get("enable_llm_cache_for_entity_extract"): await save_to_cache( @@ -1557,8 +1563,9 @@ async def use_llm_func_with_cache( if max_tokens is not None: kwargs["max_tokens"] = max_tokens - logger.info(f"Call LLM function with query text lenght: {len(input_text)}") - return await use_llm_func(input_text, **kwargs) + logger.info(f"Call LLM function with query text length: {len(input_text)}") + res = await use_llm_func(input_text, **kwargs) + return remove_think_tags(res) def get_content_summary(content: str, max_length: int = 250) -> str: