"""
Hugging Face LLM Interface Module
=================================

This module provides interfaces for interacting with Hugging Face's language models,
including text generation and embedding capabilities.

Author: LightRAG team
Created: 2024-01-24
License: MIT License

Copyright (c) 2024 LightRAG

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

Version: 1.0.0

Change Log:
- 1.0.0 (2024-01-24): Initial release
    * Added async chat completion support
    * Added embedding generation
    * Added stream response capability

Dependencies:
    - transformers
    - torch
    - tenacity
    - numpy
    - pipmaster
    - Python >= 3.10

Usage:
    from llm_interfaces.hf import hf_model_complete, hf_embed
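
    A minimal, illustrative sketch (the model name below is only an example,
    not a default of this module):

        import asyncio
        from transformers import AutoModel, AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
        embed_model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
        vectors = asyncio.run(hf_embed(["hello world"], tokenizer, embed_model))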
"""

__version__ = "1.0.0"
__author__ = "LightRAG Team"
__status__ = "Production"

import copy
import os
import pipmaster as pm  # Pipmaster for dynamic library install

# install specific modules
if not pm.is_installed("transformers"):
    pm.install("transformers")
if not pm.is_installed("torch"):
    pm.install("torch")
if not pm.is_installed("tenacity"):
    pm.install("tenacity")

from transformers import AutoTokenizer, AutoModelForCausalLM
from functools import lru_cache
from tenacity import (
    retry,
    stop_after_attempt,
    wait_exponential,
    retry_if_exception_type,
)
from lightrag.exceptions import (
    APIConnectionError,
    RateLimitError,
    APITimeoutError,
)
from lightrag.utils import (
    locate_json_string_body_from_string,
)
import torch
import numpy as np

# Disable tokenizers parallelism to avoid fork-related warnings and deadlocks.
os.environ["TOKENIZERS_PARALLELISM"] = "false"


@lru_cache(maxsize=1)
def initialize_hf_model(model_name):
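    """Load and cache a Hugging Face tokenizer and causal LM for model_name.

    The lru_cache(maxsize=1) decorator keeps the most recently loaded
    model/tokenizer pair in memory, so repeated calls with the same name do not
    reload the weights. If the tokenizer has no pad token, the EOS token is
    reused as the pad token.
    """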
    hf_tokenizer = AutoTokenizer.from_pretrained(
        model_name, device_map="auto", trust_remote_code=True
    )
    hf_model = AutoModelForCausalLM.from_pretrained(
        model_name, device_map="auto", trust_remote_code=True
    )
    if hf_tokenizer.pad_token is None:
        hf_tokenizer.pad_token = hf_tokenizer.eos_token

    return hf_model, hf_tokenizer


@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type(
        (RateLimitError, APIConnectionError, APITimeoutError)
    ),
)
async def hf_model_if_cache(
    model,
    prompt,
    system_prompt=None,
    history_messages=[],
    **kwargs,
) -> str:
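    """Generate a chat completion with a locally loaded Hugging Face causal LM.

    Builds a message list from system_prompt, history_messages and prompt,
    renders it with the tokenizer's chat template (falling back to a hand-built
    <role>...</role> prompt when no usable template exists), and returns up to
    512 newly generated tokens as text. Errors matching the retry policy above
    are retried with exponential backoff.
    """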
    model_name = model
    hf_model, hf_tokenizer = initialize_hf_model(model_name)
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})
    kwargs.pop("hashing_kv", None)
    input_prompt = ""
    try:
        # Prefer the tokenizer's built-in chat template when one is available.
        input_prompt = hf_tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    except Exception:
        try:
            # Fallback: some templates reject a system role, so merge the system
            # prompt into the first user message and re-apply the template.
            ori_message = copy.deepcopy(messages)
            if messages[0]["role"] == "system":
                messages[1]["content"] = (
                    "<system>"
                    + messages[0]["content"]
                    + "</system>\n"
                    + messages[1]["content"]
                )
                messages = messages[1:]
                input_prompt = hf_tokenizer.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=True
                )
        except Exception:
            # Last resort: no usable chat template, so build a plain
            # <role>content</role> prompt by hand.
            len_message = len(ori_message)
            for msgid in range(len_message):
                input_prompt = (
                    input_prompt
                    + "<"
                    + ori_message[msgid]["role"]
                    + ">"
                    + ori_message[msgid]["content"]
                    + "</"
                    + ori_message[msgid]["role"]
                    + ">\n"
                )

    # Tokenize and move inputs to the model's device (works on CPU as well as CUDA).
    inputs = hf_tokenizer(
        input_prompt, return_tensors="pt", padding=True, truncation=True
    ).to(hf_model.device)
    output = hf_model.generate(
        **inputs, max_new_tokens=512, num_return_sequences=1, early_stopping=True
    )
    # Decode only the newly generated tokens, skipping the echoed prompt.
    response_text = hf_tokenizer.decode(
        output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
    )

    return response_text


async def hf_model_complete(
    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
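    """LightRAG-facing completion wrapper around hf_model_if_cache.

    Reads the model name from kwargs["hashing_kv"].global_config["llm_model_name"]
    and, when keyword_extraction is set, returns only the JSON body located in
    the model's response.
    """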
    model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
    result = await hf_model_if_cache(
        model_name,
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        **kwargs,
    )
    if keyword_extraction:  # TODO: use JSON API
        return locate_json_string_body_from_string(result)
    return result


async def hf_embed(texts: list[str], tokenizer, embed_model) -> np.ndarray:
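    """Embed texts with a Hugging Face encoder by mean-pooling hidden states.

    Runs the tokenizer and embed_model on the model's own device, averages the
    last hidden state over the sequence dimension, and returns the result as a
    numpy array (upcast from bfloat16 to float32 when necessary).
    """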
    device = next(embed_model.parameters()).device
    input_ids = tokenizer(
        texts, return_tensors="pt", padding=True, truncation=True
    ).input_ids.to(device)
    with torch.no_grad():
        outputs = embed_model(input_ids)
        # Mean-pool the last hidden states to get one vector per input text.
        embeddings = outputs.last_hidden_state.mean(dim=1)
    if embeddings.dtype == torch.bfloat16:
        # NumPy has no bfloat16, so upcast before converting.
        return embeddings.detach().to(torch.float32).cpu().numpy()
    else:
        return embeddings.detach().cpu().numpy()