LightRAG/lightrag/llm/zhipu.py
2025-02-18 21:12:06 +01:00

206 lines
6.6 KiB
Python

import sys
import re
import json
from ..utils import verbose_debug
if sys.version_info < (3, 9):
pass
else:
pass
import pipmaster as pm # Pipmaster for dynamic library install
# install specific modules
if not pm.is_installed("zhipuai"):
pm.install("zhipuai")
from openai import (
APIConnectionError,
RateLimitError,
APITimeoutError,
)
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
from lightrag.utils import (
wrap_embedding_func_with_attrs,
logger,
)
from lightrag.types import GPTKeywordExtractionFormat
import numpy as np
from typing import Union, List, Optional, Dict
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=10),
retry=retry_if_exception_type(
(RateLimitError, APIConnectionError, APITimeoutError)
),
)
async def zhipu_complete_if_cache(
prompt: Union[str, List[Dict[str, str]]],
model: str = "glm-4-flashx", # The most cost/performance balance model in glm-4 series
api_key: Optional[str] = None,
system_prompt: Optional[str] = None,
history_messages: List[Dict[str, str]] = [],
**kwargs,
) -> str:
# dynamically load ZhipuAI
try:
from zhipuai import ZhipuAI
except ImportError:
raise ImportError("Please install zhipuai before initialize zhipuai backend.")
if api_key:
client = ZhipuAI(api_key=api_key)
else:
# please set ZHIPUAI_API_KEY in your environment
# os.environ["ZHIPUAI_API_KEY"]
client = ZhipuAI()
messages = []
if not system_prompt:
system_prompt = "You are a helpful assistant. Note that sensitive words in the content should be replaced with ***"
# Add system prompt if provided
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.extend(history_messages)
messages.append({"role": "user", "content": prompt})
# Add debug logging
logger.debug("===== Query Input to LLM =====")
logger.debug(f"Query: {prompt}")
verbose_debug(f"System prompt: {system_prompt}")
# Remove unsupported kwargs
kwargs = {
k: v for k, v in kwargs.items() if k not in ["hashing_kv", "keyword_extraction"]
}
response = client.chat.completions.create(model=model, messages=messages, **kwargs)
return response.choices[0].message.content
async def zhipu_complete(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
):
# Pop keyword_extraction from kwargs to avoid passing it to zhipu_complete_if_cache
keyword_extraction = kwargs.pop("keyword_extraction", None)
if keyword_extraction:
# Add a system prompt to guide the model to return JSON format
extraction_prompt = """You are a helpful assistant that extracts keywords from text.
Please analyze the content and extract two types of keywords:
1. High-level keywords: Important concepts and main themes
2. Low-level keywords: Specific details and supporting elements
Return your response in this exact JSON format:
{
"high_level_keywords": ["keyword1", "keyword2"],
"low_level_keywords": ["keyword1", "keyword2", "keyword3"]
}
Only return the JSON, no other text."""
# Combine with existing system prompt if any
if system_prompt:
system_prompt = f"{system_prompt}\n\n{extraction_prompt}"
else:
system_prompt = extraction_prompt
try:
response = await zhipu_complete_if_cache(
prompt=prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
# Try to parse as JSON
try:
data = json.loads(response)
return GPTKeywordExtractionFormat(
high_level_keywords=data.get("high_level_keywords", []),
low_level_keywords=data.get("low_level_keywords", []),
)
except json.JSONDecodeError:
# If direct JSON parsing fails, try to extract JSON from text
match = re.search(r"\{[\s\S]*\}", response)
if match:
try:
data = json.loads(match.group())
return GPTKeywordExtractionFormat(
high_level_keywords=data.get("high_level_keywords", []),
low_level_keywords=data.get("low_level_keywords", []),
)
except json.JSONDecodeError:
pass
# If all parsing fails, log warning and return empty format
logger.warning(
f"Failed to parse keyword extraction response: {response}"
)
return GPTKeywordExtractionFormat(
high_level_keywords=[], low_level_keywords=[]
)
except Exception as e:
logger.error(f"Error during keyword extraction: {str(e)}")
return GPTKeywordExtractionFormat(
high_level_keywords=[], low_level_keywords=[]
)
else:
# For non-keyword-extraction, just return the raw response string
return await zhipu_complete_if_cache(
prompt=prompt,
system_prompt=system_prompt,
history_messages=history_messages,
**kwargs,
)
@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential(multiplier=1, min=4, max=60),
retry=retry_if_exception_type(
(RateLimitError, APIConnectionError, APITimeoutError)
),
)
async def zhipu_embedding(
texts: list[str], model: str = "embedding-3", api_key: str = None, **kwargs
) -> np.ndarray:
# dynamically load ZhipuAI
try:
from zhipuai import ZhipuAI
except ImportError:
raise ImportError("Please install zhipuai before initialize zhipuai backend.")
if api_key:
client = ZhipuAI(api_key=api_key)
else:
# please set ZHIPUAI_API_KEY in your environment
# os.environ["ZHIPUAI_API_KEY"]
client = ZhipuAI()
# Convert single text to list if needed
if isinstance(texts, str):
texts = [texts]
embeddings = []
for text in texts:
try:
response = client.embeddings.create(model=model, input=[text], **kwargs)
embeddings.append(response.data[0].embedding)
except Exception as e:
raise Exception(f"Error calling ChatGLM Embedding API: {str(e)}")
return np.array(embeddings)