import re
import json

from ..utils import verbose_debug

import pipmaster as pm  # pipmaster handles dynamic library installation

# Install specific modules
if not pm.is_installed("zhipuai"):
    pm.install("zhipuai")

from openai import (
    APIConnectionError,
    RateLimitError,
    APITimeoutError,
)
from tenacity import (
    retry,
    stop_after_attempt,
    wait_exponential,
    retry_if_exception_type,
)

from lightrag.utils import (
    wrap_embedding_func_with_attrs,
    logger,
)
from lightrag.types import GPTKeywordExtractionFormat

import numpy as np
from typing import Union, List, Optional, Dict


@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type(
        (RateLimitError, APIConnectionError, APITimeoutError)
    ),
)
async def zhipu_complete_if_cache(
    prompt: Union[str, List[Dict[str, str]]],
    model: str = "glm-4-flashx",  # Best cost/performance balance in the GLM-4 series
    api_key: Optional[str] = None,
    system_prompt: Optional[str] = None,
    history_messages: List[Dict[str, str]] = [],
    **kwargs,
) -> str:
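    """Send a chat completion request to the Zhipu GLM API.

    Builds the message list from ``system_prompt``, ``history_messages``, and
    ``prompt``, calls the model, and returns the completion text. Transient
    errors (rate limit, connection, timeout) are retried up to three times
    with exponential backoff by the decorator above.
    """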
    # Dynamically import the ZhipuAI SDK
    try:
        from zhipuai import ZhipuAI
    except ImportError:
        raise ImportError(
            "Please install zhipuai before initializing the zhipuai backend."
        )

    if api_key:
        client = ZhipuAI(api_key=api_key)
    else:
        # Falls back to the ZHIPUAI_API_KEY environment variable
        client = ZhipuAI()

    messages = []

    if not system_prompt:
        system_prompt = "You are a helpful assistant. Note that sensitive words in the content should be replaced with ***"

    # Build the conversation: system prompt, then history, then the user prompt
    messages.append({"role": "system", "content": system_prompt})
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})

    # Debug logging
    logger.debug("===== Query Input to LLM =====")
    logger.debug(f"Query: {prompt}")
    verbose_debug(f"System prompt: {system_prompt}")

    # Remove kwargs that the Zhipu SDK does not accept
    kwargs = {
        k: v for k, v in kwargs.items() if k not in ["hashing_kv", "keyword_extraction"]
    }

    response = client.chat.completions.create(model=model, messages=messages, **kwargs)

    return response.choices[0].message.content


async def zhipu_complete(
    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
):
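    """Complete a prompt, optionally extracting keywords as structured JSON.

    With ``keyword_extraction`` enabled, the model is instructed to answer in
    a fixed JSON format, which is parsed into a ``GPTKeywordExtractionFormat``;
    otherwise the raw completion string is returned.
    """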
    # keyword_extraction may arrive as the named parameter or inside kwargs;
    # pop it from kwargs so it is not forwarded to zhipu_complete_if_cache
    keyword_extraction = keyword_extraction or kwargs.pop("keyword_extraction", None)

    if keyword_extraction:
        # Add a system prompt to guide the model to return JSON format
        extraction_prompt = """You are a helpful assistant that extracts keywords from text.
Please analyze the content and extract two types of keywords:
1. High-level keywords: Important concepts and main themes
2. Low-level keywords: Specific details and supporting elements

Return your response in this exact JSON format:
{
    "high_level_keywords": ["keyword1", "keyword2"],
    "low_level_keywords": ["keyword1", "keyword2", "keyword3"]
}

Only return the JSON, no other text."""

        # Combine with the existing system prompt, if any
        if system_prompt:
            system_prompt = f"{system_prompt}\n\n{extraction_prompt}"
        else:
            system_prompt = extraction_prompt

        try:
            response = await zhipu_complete_if_cache(
                prompt=prompt,
                system_prompt=system_prompt,
                history_messages=history_messages,
                **kwargs,
            )

            # Try to parse the whole response as JSON
            try:
                data = json.loads(response)
                return GPTKeywordExtractionFormat(
                    high_level_keywords=data.get("high_level_keywords", []),
                    low_level_keywords=data.get("low_level_keywords", []),
                )
            except json.JSONDecodeError:
                # If direct parsing fails, try to extract a JSON object from the text
                match = re.search(r"\{[\s\S]*\}", response)
                if match:
                    try:
                        data = json.loads(match.group())
                        return GPTKeywordExtractionFormat(
                            high_level_keywords=data.get("high_level_keywords", []),
                            low_level_keywords=data.get("low_level_keywords", []),
                        )
                    except json.JSONDecodeError:
                        pass

            # If all parsing fails, log a warning and return an empty result
            logger.warning(f"Failed to parse keyword extraction response: {response}")
            return GPTKeywordExtractionFormat(
                high_level_keywords=[], low_level_keywords=[]
            )
        except Exception as e:
            logger.error(f"Error during keyword extraction: {str(e)}")
            return GPTKeywordExtractionFormat(
                high_level_keywords=[], low_level_keywords=[]
            )
    else:
        # For plain completions, return the raw response string
        return await zhipu_complete_if_cache(
            prompt=prompt,
            system_prompt=system_prompt,
            history_messages=history_messages,
            **kwargs,
        )


@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=60),
    retry=retry_if_exception_type(
        (RateLimitError, APIConnectionError, APITimeoutError)
    ),
)
async def zhipu_embedding(
    texts: Union[str, list[str]],
    model: str = "embedding-3",
    api_key: Optional[str] = None,
    **kwargs,
) -> np.ndarray:
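    """Embed one or more texts with Zhipu's embedding API.

    Accepts a single string or a list of strings and returns an array of
    shape ``(len(texts), 1024)``. Transient API errors are retried by the
    decorator above.
    """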
    # Dynamically import the ZhipuAI SDK
    try:
        from zhipuai import ZhipuAI
    except ImportError:
        raise ImportError(
            "Please install zhipuai before initializing the zhipuai backend."
        )

    if api_key:
        client = ZhipuAI(api_key=api_key)
    else:
        # Falls back to the ZHIPUAI_API_KEY environment variable
        client = ZhipuAI()

    # Accept a single string as well as a list of strings
    if isinstance(texts, str):
        texts = [texts]

    embeddings = []
    for text in texts:
        try:
            response = client.embeddings.create(model=model, input=[text], **kwargs)
            embeddings.append(response.data[0].embedding)
        except Exception as e:
            raise Exception(f"Error calling ChatGLM Embedding API: {str(e)}") from e

    return np.array(embeddings)
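
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the library API): assumes
# ZHIPUAI_API_KEY is set in the environment so ZhipuAI() can authenticate.
#
#     import asyncio
#
#     async def _demo():
#         answer = await zhipu_complete_if_cache("What is a knowledge graph?")
#         print(answer)
#         vectors = await zhipu_embedding(["hello", "world"])
#         print(vectors.shape)  # expected: (2, 1024)
#
#     asyncio.run(_demo())
# ---------------------------------------------------------------------------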