mirror of
				https://github.com/HKUDS/LightRAG.git
				synced 2025-10-31 09:49:54 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			247 lines
		
	
	
		
			7.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			247 lines
		
	
	
		
			7.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| """
 | |
| Zhipu LLM Interface Module
 | |
| ==========================
 | |
| 
 | |
| This module provides interfaces for interacting with Zhipu AI's GLM language models,
 | |
| including text generation and embedding capabilities.
 | |
| 
 | |
| Author: Lightrag team
 | |
| Created: 2024-01-24
 | |
| License: MIT License
 | |
| 
 | |
| Copyright (c) 2024 Lightrag
 | |
| 
 | |
| Permission is hereby granted, free of charge, to any person obtaining a copy
 | |
| of this software and associated documentation files (the "Software"), to deal
 | |
| in the Software without restriction, including without limitation the rights
 | |
| to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 | |
| copies of the Software, and to permit persons to whom the Software is
 | |
| furnished to do so, subject to the following conditions:
 | |
| 
 | |
| Version: 1.0.0
 | |
| 
 | |
| Change Log:
 | |
| - 1.0.0 (2024-01-24): Initial release
 | |
|     * Added async chat completion support
 | |
|     * Added embedding generation
 | |
|     * Added stream response capability
 | |
| 
 | |
| Dependencies:
 | |
|     - tenacity
 | |
|     - numpy
 | |
|     - pipmaster
 | |
|     - Python >= 3.10
 | |
| 
 | |
| Usage:
 | |
|     from lightrag.llm.zhipu import zhipu_complete, zhipu_embedding
 | |
| """
 | |
| 
 | |
# Module metadata; the version mirrors the "Version" entry in the module docstring.
__status__ = "Production"
__author__ = "lightrag Team"
__version__ = "1.0.0"
 | |
| 
 | |
| import sys
 | |
| import re
 | |
| import json
 | |
| 
 | |
import pipmaster as pm  # Pipmaster for dynamic library install

# install specific modules
# Importing this module installs the ``zhipuai`` SDK on the fly when it is
# missing, so that the lazy ``from zhipuai import ZhipuAI`` inside the
# functions below can succeed. (The former ``sys.version_info`` check had
# ``pass`` in both branches and was removed as dead code.)
if not pm.is_installed("zhipuai"):
    pm.install("zhipuai")
 | |
| 
 | |
| from openai import (
 | |
|     APIConnectionError,
 | |
|     RateLimitError,
 | |
|     APITimeoutError,
 | |
| )
 | |
| from tenacity import (
 | |
|     retry,
 | |
|     stop_after_attempt,
 | |
|     wait_exponential,
 | |
|     retry_if_exception_type,
 | |
| )
 | |
| 
 | |
| from lightrag.utils import (
 | |
|     wrap_embedding_func_with_attrs,
 | |
|     logger,
 | |
| )
 | |
| 
 | |
| from lightrag.types import GPTKeywordExtractionFormat
 | |
| 
 | |
| import numpy as np
 | |
| from typing import Union, List, Optional, Dict
 | |
| 
 | |
| 
 | |
# NOTE(review): the retry predicate matches *openai* exception classes. If the
# zhipuai SDK raises its own exception types, these retries never fire —
# confirm which exceptions the SDK actually raises.
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type(
        (RateLimitError, APIConnectionError, APITimeoutError)
    ),
)
async def zhipu_complete_if_cache(
    prompt: Union[str, List[Dict[str, str]]],
    model: str = "glm-4-flashx",  # Best cost/performance balance in the glm-4 series
    api_key: Optional[str] = None,
    system_prompt: Optional[str] = None,
    history_messages: Optional[List[Dict[str, str]]] = None,
    **kwargs,
) -> str:
    """Send one chat-completion request to a Zhipu GLM model.

    Despite its name, this function does no caching itself. ``hashing_kv``
    and ``keyword_extraction`` are accepted only so they can be stripped from
    ``kwargs`` before the request is sent.

    Args:
        prompt: Content of the final ``user`` message.
        model: Zhipu model name.
        api_key: API key. When omitted, ``ZhipuAI()`` is built without one
            (presumably the SDK then reads ``ZHIPUAI_API_KEY`` from the
            environment — confirm against the installed SDK).
        system_prompt: System message. A default assistant prompt is used when
            this is empty or ``None``.
        history_messages: Prior chat messages, inserted between the system
            message and the new user message. Not mutated. ``None`` means no
            history.
        **kwargs: Forwarded to ``client.chat.completions.create`` (minus the
            stripped keys above).

    Returns:
        The message content of the first choice in the response.

    Raises:
        ImportError: If the ``zhipuai`` package cannot be imported.
    """
    import asyncio

    # dynamically load ZhipuAI
    try:
        from zhipuai import ZhipuAI
    except ImportError as err:
        raise ImportError("Please install zhipuai before initialize zhipuai backend.") from err

    # Without an explicit key, defer to the SDK's default key lookup
    # (please set ZHIPUAI_API_KEY in your environment).
    client = ZhipuAI(api_key=api_key) if api_key else ZhipuAI()

    if not system_prompt:
        system_prompt = "You are a helpful assistant. Note that sensitive words in the content should be replaced with ***"

    # A system prompt is always present here, because a default is supplied above.
    messages = [{"role": "system", "content": system_prompt}]
    messages.extend(history_messages or [])
    messages.append({"role": "user", "content": prompt})

    # Add debug logging
    logger.debug("===== Query Input to LLM =====")
    logger.debug(f"Query: {prompt}")
    logger.debug(f"System prompt: {system_prompt}")

    # Remove unsupported kwargs
    kwargs = {
        k: v for k, v in kwargs.items() if k not in ["hashing_kv", "keyword_extraction"]
    }

    # The zhipuai client is synchronous. Run the blocking HTTP call in a worker
    # thread so it does not stall every other coroutine on the event loop.
    response = await asyncio.to_thread(
        client.chat.completions.create, model=model, messages=messages, **kwargs
    )

    return response.choices[0].message.content
 | |
| 
 | |
| 
 | |
def _parse_keyword_response(response):
    """Convert a keyword-extraction reply into ``GPTKeywordExtractionFormat``.

    The whole reply is parsed as JSON first. If that fails, the span from the
    first ``{`` to the last ``}`` is tried. When neither attempt yields a JSON
    object, a warning is logged and empty keyword lists are returned.
    """
    try:
        data = json.loads(response)
    except json.JSONDecodeError:
        data = None
        # The model may surround the JSON with prose; fall back to the
        # greedy outermost-brace span.
        match = re.search(r"\{[\s\S]*\}", response)
        if match:
            try:
                data = json.loads(match.group())
            except json.JSONDecodeError:
                pass

    # Only a JSON object can carry the expected keys. Lists, strings and
    # numbers count as a parse failure instead of crashing on ``.get``.
    if isinstance(data, dict):
        return GPTKeywordExtractionFormat(
            high_level_keywords=data.get("high_level_keywords", []),
            low_level_keywords=data.get("low_level_keywords", []),
        )

    # If all parsing fails, log warning and return empty format
    logger.warning(f"Failed to parse keyword extraction response: {response}")
    return GPTKeywordExtractionFormat(high_level_keywords=[], low_level_keywords=[])


async def zhipu_complete(
    prompt, system_prompt=None, history_messages=None, keyword_extraction=False, **kwargs
):
    """Chat-completion entry point for a Zhipu GLM model.

    Delegates to ``zhipu_complete_if_cache`` and returns its raw reply string.
    A keyword-extraction mode also exists: it uses a JSON-only system prompt
    and parses the reply into ``GPTKeywordExtractionFormat``. That mode is
    currently unreachable; see the NOTE on ``keyword_extraction`` below.

    Args:
        prompt: User message content.
        system_prompt: Optional system prompt.
        history_messages: Prior chat messages. ``None`` means no history.
        keyword_extraction: Currently overridden to ``None`` (see NOTE).
        **kwargs: Forwarded to ``zhipu_complete_if_cache``.

    Returns:
        The model's reply string. In the (disabled) keyword mode, a
        ``GPTKeywordExtractionFormat`` instead; that mode never raises and
        returns empty keyword lists on any failure.
    """
    if history_messages is None:
        history_messages = []

    # NOTE(review): ``keyword_extraction`` is bound by the named parameter, so
    # it can never appear in **kwargs. This pop therefore always yields None,
    # and the keyword-extraction branch below never runs. Honouring the flag
    # would change the return type from str to GPTKeywordExtractionFormat, and
    # callers may rely on receiving a str. Confirm with the callers before
    # changing this line.
    keyword_extraction = kwargs.pop("keyword_extraction", None)

    if keyword_extraction:
        # Add a system prompt to guide the model to return JSON format
        extraction_prompt = """You are a helpful assistant that extracts keywords from text.
        Please analyze the content and extract two types of keywords:
        1. High-level keywords: Important concepts and main themes
        2. Low-level keywords: Specific details and supporting elements

        Return your response in this exact JSON format:
        {
            "high_level_keywords": ["keyword1", "keyword2"],
            "low_level_keywords": ["keyword1", "keyword2", "keyword3"]
        }

        Only return the JSON, no other text."""

        # Combine with existing system prompt if any
        if system_prompt:
            system_prompt = f"{system_prompt}\n\n{extraction_prompt}"
        else:
            system_prompt = extraction_prompt

        # Best-effort: any request or parsing failure degrades to empty keyword
        # lists rather than propagating.
        try:
            response = await zhipu_complete_if_cache(
                prompt=prompt,
                system_prompt=system_prompt,
                history_messages=history_messages,
                **kwargs,
            )
            return _parse_keyword_response(response)
        except Exception as e:
            logger.error(f"Error during keyword extraction: {str(e)}")
            return GPTKeywordExtractionFormat(
                high_level_keywords=[], low_level_keywords=[]
            )

    # For non-keyword-extraction, just return the raw response string
    return await zhipu_complete_if_cache(
        prompt=prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        **kwargs,
    )
 | |
| 
 | |
| 
 | |
# NOTE(review): embedding_dim=1024 must match the vectors the model actually
# returns. Confirm embedding-3's default output size; if it differs, a size
# option may need to be passed through **kwargs.
@wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=60),
    retry=retry_if_exception_type(
        (RateLimitError, APIConnectionError, APITimeoutError)
    ),
)
async def zhipu_embedding(
    texts: list[str], model: str = "embedding-3", api_key: Optional[str] = None, **kwargs
) -> np.ndarray:
    """Embed texts with a Zhipu embedding model, one API request per text.

    Args:
        texts: Texts to embed. A single ``str`` is also accepted and treated
            as a one-element list.
        model: Zhipu embedding model name.
        api_key: API key. When omitted, ``ZhipuAI()`` is built without one
            (presumably the SDK then reads ``ZHIPUAI_API_KEY`` from the
            environment — confirm).
        **kwargs: Forwarded to every ``client.embeddings.create`` call.

    Returns:
        ``np.array`` of the returned embedding vectors, in input order.

    Raises:
        ImportError: If the ``zhipuai`` package cannot be imported.
        Exception: Wraps any error raised while calling the API or reading
            its response. The original error is chained as ``__cause__``.
    """
    import asyncio

    # dynamically load ZhipuAI
    try:
        from zhipuai import ZhipuAI
    except ImportError as err:
        raise ImportError("Please install zhipuai before initialize zhipuai backend.") from err

    # Without an explicit key, defer to the SDK's default key lookup
    # (please set ZHIPUAI_API_KEY in your environment).
    client = ZhipuAI(api_key=api_key) if api_key else ZhipuAI()

    # Convert single text to list if needed
    if isinstance(texts, str):
        texts = [texts]

    def _embed_all():
        # Synchronous loop over the SDK: one request per text, order preserved.
        embeddings = []
        for text in texts:
            try:
                response = client.embeddings.create(model=model, input=[text], **kwargs)
                embeddings.append(response.data[0].embedding)
            except Exception as e:
                # NOTE(review): re-raising as a bare Exception hides the original
                # type from the @retry predicate above, so those retries cannot
                # match errors raised here.
                raise Exception(f"Error calling ChatGLM Embedding API: {str(e)}") from e
        return embeddings

    # Run the blocking SDK calls in a worker thread so the event loop stays
    # responsive while the embeddings are fetched.
    return np.array(await asyncio.to_thread(_embed_all))
 | 
