Mirror of https://github.com/HKUDS/LightRAG.git (synced 2025-12-11 23:08:17 +00:00)

Commit 1854d7c75a
Merge remote-tracking branch 'upstream/memgraph' into add-Memgraph-graph-db
@@ -30,7 +30,7 @@
 <a href="https://github.com/HKUDS/LightRAG/issues/285"><img src="https://img.shields.io/badge/💬微信群-交流-07c160?style=for-the-badge&logo=wechat&logoColor=white&labelColor=1a1a2e"></a>
 </p>
 <p>
-<a href="README_zh.md"><img src="https://img.shields.io/badge/🇨🇳中文版-1a1a2e?style=for-the-badge"></a>
+<a href="README-zh.md"><img src="https://img.shields.io/badge/🇨🇳中文版-1a1a2e?style=for-the-badge"></a>
 <a href="README.md"><img src="https://img.shields.io/badge/🇺🇸English-1a1a2e?style=for-the-badge"></a>
 </p>
 </div>
@@ -96,7 +96,7 @@ EMBEDDING_BINDING_API_KEY=your_api_key
 # If the embedding service is deployed within the same Docker stack, use host.docker.internal instead of localhost
 EMBEDDING_BINDING_HOST=http://localhost:11434
 ### Num of chunks send to Embedding in single request
-# EMBEDDING_BATCH_NUM=32
+# EMBEDDING_BATCH_NUM=10
 ### Max concurrency requests for Embedding
 # EMBEDDING_FUNC_MAX_ASYNC=16
 ### Maximum tokens sent to Embedding for each chunk (no longer in use?)
@@ -201,7 +201,7 @@ class LightRAG:
     embedding_func: EmbeddingFunc | None = field(default=None)
     """Function for computing text embeddings. Must be set before use."""

-    embedding_batch_num: int = field(default=int(os.getenv("EMBEDDING_BATCH_NUM", 32)))
+    embedding_batch_num: int = field(default=int(os.getenv("EMBEDDING_BATCH_NUM", 10)))
     """Batch size for embedding computations."""

     embedding_func_max_async: int = field(
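Both hunks above lower the default embedding batch size from 32 to 10: once in the sample environment file and once in the LightRAG dataclass that reads it. A minimal sketch of how the env-driven default would translate into batched embedding requests follows; the EmbeddingConfig class, embed_in_batches helper, and embed_fn are illustrative only and not part of the repository:

import os
from dataclasses import dataclass, field


@dataclass
class EmbeddingConfig:
    # Mirrors the changed default: EMBEDDING_BATCH_NUM falls back to 10 when unset.
    batch_num: int = field(default=int(os.getenv("EMBEDDING_BATCH_NUM", 10)))


def embed_in_batches(chunks: list[str], embed_fn, batch_num: int) -> list:
    """Send chunks to the embedding function in groups of batch_num."""
    vectors = []
    for i in range(0, len(chunks), batch_num):
        vectors.extend(embed_fn(chunks[i : i + batch_num]))
    return vectors

A smaller batch presumably keeps each embedding request smaller at the cost of more round trips; EMBEDDING_FUNC_MAX_ASYNC still caps how many of those requests run concurrently.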
@@ -210,9 +210,18 @@ async def openai_complete_if_cache(
         async def inner():
             # Track if we've started iterating
             iteration_started = False
+            final_chunk_usage = None
+
             try:
                 iteration_started = True
                 async for chunk in response:
+                    # Check if this chunk has usage information (final chunk)
+                    if hasattr(chunk, "usage") and chunk.usage:
+                        final_chunk_usage = chunk.usage
+                        logger.debug(
+                            f"Received usage info in streaming chunk: {chunk.usage}"
+                        )
+
                     # Check if choices exists and is not empty
                     if not hasattr(chunk, "choices") or not chunk.choices:
                         logger.warning(f"Received chunk without choices: {chunk}")
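The new final_chunk_usage bookkeeping assumes the stream ends with a chunk whose usage attribute is populated. With the official OpenAI Python client that terminal chunk is only emitted when the request opts in through stream_options; whether the caller sets this flag is not visible in this hunk. A hedged request-side sketch, with placeholder model name and message:

from openai import AsyncOpenAI


async def stream_with_usage() -> None:
    client = AsyncOpenAI()
    response = await client.chat.completions.create(
        model="gpt-4o-mini",  # placeholder model
        messages=[{"role": "user", "content": "hello"}],
        stream=True,
        stream_options={"include_usage": True},  # request a final usage-bearing chunk
    )
    async for chunk in response:
        # The last chunk typically has empty choices but a populated usage object,
        # which is exactly what the change above checks for.
        if getattr(chunk, "usage", None):
            print(chunk.usage)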
@@ -222,16 +231,31 @@ async def openai_complete_if_cache(
                     if not hasattr(chunk.choices[0], "delta") or not hasattr(
                         chunk.choices[0].delta, "content"
                     ):
-                        logger.warning(
-                            f"Received chunk without delta content: {chunk.choices[0]}"
-                        )
+                        # This might be the final chunk, continue to check for usage
                         continue
+
                     content = chunk.choices[0].delta.content
                     if content is None:
                         continue
                     if r"\u" in content:
                         content = safe_unicode_decode(content.encode("utf-8"))
+
                     yield content
+
+                # After streaming is complete, track token usage
+                if token_tracker and final_chunk_usage:
+                    # Use actual usage from the API
+                    token_counts = {
+                        "prompt_tokens": getattr(final_chunk_usage, "prompt_tokens", 0),
+                        "completion_tokens": getattr(
+                            final_chunk_usage, "completion_tokens", 0
+                        ),
+                        "total_tokens": getattr(final_chunk_usage, "total_tokens", 0),
+                    }
+                    token_tracker.add_usage(token_counts)
+                    logger.debug(f"Streaming token usage (from API): {token_counts}")
+                elif token_tracker:
+                    logger.debug("No usage information available in streaming response")
             except Exception as e:
                 logger.error(f"Error in stream response: {str(e)}")
                 # Try to clean up resources if possible
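Once the stream finishes, the captured usage is passed to token_tracker.add_usage() as a plain dict with prompt_tokens, completion_tokens, and total_tokens. The tracker itself is not shown in this diff; below is a minimal sketch of an object that would satisfy that call, where the class name and attributes are assumptions:

class SimpleTokenTracker:
    """Accumulates usage dicts of the shape built in the hunk above."""

    def __init__(self) -> None:
        self.prompt_tokens = 0
        self.completion_tokens = 0
        self.total_tokens = 0

    def add_usage(self, counts: dict) -> None:
        self.prompt_tokens += counts.get("prompt_tokens", 0)
        self.completion_tokens += counts.get("completion_tokens", 0)
        self.total_tokens += counts.get("total_tokens", 0)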
@@ -26,6 +26,7 @@ from .utils import (
     get_conversation_turns,
     use_llm_func_with_cache,
     update_chunk_cache_list,
+    remove_think_tags,
 )
 from .base import (
     BaseGraphStorage,
@@ -1703,7 +1704,8 @@ async def extract_keywords_only(
     result = await use_model_func(kw_prompt, keyword_extraction=True)

     # 6. Parse out JSON from the LLM response
-    match = re.search(r"\{.*\}", result, re.DOTALL)
+    result = remove_think_tags(result)
+    match = re.search(r"\{.*?\}", result, re.DOTALL)
     if not match:
         logger.error("No JSON-like structure found in the LLM respond.")
         return [], []
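Keyword parsing changes in two ways: the raw model output is first cleaned with remove_think_tags (added further down in this commit), and the JSON search switches from the greedy \{.*\} to the non-greedy \{.*?\}, which stops at the first closing brace instead of running to the last one. A quick illustration of the regex difference, using a made-up sample string:

import re

sample = (
    'Keywords: {"high_level_keywords": ["graph"], "low_level_keywords": ["node"]}\n'
    'Extra commentary: {"note": "ignored"}'
)

greedy = re.search(r"\{.*\}", sample, re.DOTALL).group(0)
lazy = re.search(r"\{.*?\}", sample, re.DOTALL).group(0)

# greedy spans from the first '{' to the LAST '}' and is not valid JSON here;
# lazy stops at the first '}', which is enough for the flat keyword object.
print(greedy)
print(lazy)

The non-greedy form relies on the keyword JSON containing no nested braces, which holds for the flat two-list structure expected here.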
@@ -1465,6 +1465,11 @@ async def update_chunk_cache_list(
     )


+def remove_think_tags(text: str) -> str:
+    """Remove <think> tags from the text"""
+    return re.sub(r"^(<think>.*?</think>|<think>)", "", text, flags=re.DOTALL).strip()
+
+
 async def use_llm_func_with_cache(
     input_text: str,
     use_llm_func: callable,
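The new helper strips the leading reasoning block that reasoning-style models prepend as a <think>...</think> prefix. Because of the ^(<think>.*?</think>|<think>) alternation, a closed block is removed entirely, an unterminated opening tag is dropped on its own, and text without tags passes through unchanged. A small behavioural sketch with invented inputs:

import re


def remove_think_tags(text: str) -> str:
    """Remove <think> tags from the text"""
    return re.sub(r"^(<think>.*?</think>|<think>)", "", text, flags=re.DOTALL).strip()


print(remove_think_tags('<think>reasoning...</think>{"answer": 42}'))
# -> {"answer": 42}                    (leading reasoning block dropped)

print(remove_think_tags('<think>still reasoning {"answer": 42}'))
# -> still reasoning {"answer": 42}    (no closing tag: only the opening tag is stripped)

print(remove_think_tags("plain response"))
# -> plain response                    (unchanged)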
@@ -1531,6 +1536,7 @@ async def use_llm_func_with_cache(
            kwargs["max_tokens"] = max_tokens

        res: str = await use_llm_func(input_text, **kwargs)
+        res = remove_think_tags(res)

        if llm_response_cache.global_config.get("enable_llm_cache_for_entity_extract"):
            await save_to_cache(
@@ -1557,8 +1563,9 @@ async def use_llm_func_with_cache(
    if max_tokens is not None:
        kwargs["max_tokens"] = max_tokens

-    logger.info(f"Call LLM function with query text lenght: {len(input_text)}")
-    return await use_llm_func(input_text, **kwargs)
+    logger.info(f"Call LLM function with query text length: {len(input_text)}")
+    res = await use_llm_func(input_text, **kwargs)
+    return remove_think_tags(res)


 def get_content_summary(content: str, max_length: int = 250) -> str: