import copy
import os
from functools import lru_cache

import pipmaster as pm  # Pipmaster for dynamic library install

# install specific modules
if not pm.is_installed("transformers"):
    pm.install("transformers")
if not pm.is_installed("torch"):
    pm.install("torch")
if not pm.is_installed("numpy"):
    pm.install("numpy")

from transformers import AutoTokenizer, AutoModelForCausalLM
from tenacity import (
    retry,
    stop_after_attempt,
    wait_exponential,
    retry_if_exception_type,
)
from lightrag.exceptions import (
    APIConnectionError,
    RateLimitError,
    APITimeoutError,
)
from lightrag.utils import (
    locate_json_string_body_from_string,
)
import torch
import numpy as np

os.environ["TOKENIZERS_PARALLELISM"] = "false"


@lru_cache(maxsize=1)
def initialize_hf_model(model_name):
    """Load and cache a Hugging Face causal LM and its tokenizer."""
    hf_tokenizer = AutoTokenizer.from_pretrained(
        model_name, device_map="auto", trust_remote_code=True
    )
    hf_model = AutoModelForCausalLM.from_pretrained(
        model_name, device_map="auto", trust_remote_code=True
    )
    if hf_tokenizer.pad_token is None:
        hf_tokenizer.pad_token = hf_tokenizer.eos_token
    return hf_model, hf_tokenizer


@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type(
        (RateLimitError, APIConnectionError, APITimeoutError)
    ),
)
async def hf_model_if_cache(
    model,
    prompt,
    system_prompt=None,
    history_messages=[],
    **kwargs,
) -> str:
    model_name = model
    hf_model, hf_tokenizer = initialize_hf_model(model_name)

    # Build the chat message list: optional system prompt, history, then the user prompt.
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})
    kwargs.pop("hashing_kv", None)

    # Try the tokenizer's chat template first; fall back to merging the system
    # prompt into the first user message, and finally to a plain "<role>content" join.
    input_prompt = ""
    try:
        input_prompt = hf_tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    except Exception:
        try:
            ori_message = copy.deepcopy(messages)
            if messages[0]["role"] == "system":
                messages[1]["content"] = (
                    "<system>"
                    + messages[0]["content"]
                    + "</system>\n"
                    + messages[1]["content"]
                )
                messages = messages[1:]
            input_prompt = hf_tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        except Exception:
            len_message = len(ori_message)
            for msgid in range(len_message):
                input_prompt = (
                    input_prompt
                    + "<"
                    + ori_message[msgid]["role"]
                    + ">"
                    + ori_message[msgid]["content"]
                    + "\n"
                )

    # Tokenize and move the tensors to the model's device (avoids the hardcoded
    # "cuda", which fails on CPU- or MPS-only machines).
    input_ids = hf_tokenizer(
        input_prompt, return_tensors="pt", padding=True, truncation=True
    )
    inputs = {k: v.to(hf_model.device) for k, v in input_ids.items()}
    output = hf_model.generate(
        **inputs, max_new_tokens=512, num_return_sequences=1, early_stopping=True
    )
    # Decode only the newly generated tokens (everything after the prompt).
    response_text = hf_tokenizer.decode(
        output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
    )

    return response_text


async def hf_model_complete(
    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
    keyword_extraction = kwargs.pop("keyword_extraction", keyword_extraction)
    model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
    result = await hf_model_if_cache(
        model_name,
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        **kwargs,
    )
    if keyword_extraction:  # TODO: use JSON API
        return locate_json_string_body_from_string(result)
    return result


async def hf_embed(texts: list[str], tokenizer, embed_model) -> np.ndarray:
    # Detect the appropriate device
    if torch.cuda.is_available():
        device = next(embed_model.parameters()).device  # Use CUDA if available
    elif torch.backends.mps.is_available():
        device = torch.device("mps")  # Use MPS for Apple Silicon
    else:
        device = torch.device("cpu")  # Fallback to CPU

    # Move the model to the detected device
    embed_model = embed_model.to(device)

    # Tokenize the input texts and move them to the same device
    encoded_texts = tokenizer(
        texts, return_tensors="pt", padding=True, truncation=True
    ).to(device)

    # Perform inference and mean-pool the last hidden states into one vector per text
    with torch.no_grad():
        outputs = embed_model(
            input_ids=encoded_texts["input_ids"],
            attention_mask=encoded_texts["attention_mask"],
        )
        embeddings = outputs.last_hidden_state.mean(dim=1)

    # Convert embeddings to NumPy (bfloat16 tensors cannot be converted directly)
    if embeddings.dtype == torch.bfloat16:
        return embeddings.detach().to(torch.float32).cpu().numpy()
    else:
        return embeddings.detach().cpu().numpy()
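

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the LightRAG pipeline):
    # shows how hf_model_if_cache and hf_embed might be called directly. The model
    # names below are placeholder assumptions, not defaults used by this module.
    import asyncio

    from transformers import AutoModel

    async def _demo():
        # Chat-style completion against a local Hugging Face causal LM.
        answer = await hf_model_if_cache(
            "Qwen/Qwen2.5-0.5B-Instruct",  # placeholder model name
            "Summarize what a knowledge graph is in one sentence.",
            system_prompt="You are a concise assistant.",
        )
        print(answer)

        # Embeddings with a separate encoder model supplied by the caller.
        embed_name = "sentence-transformers/all-MiniLM-L6-v2"  # placeholder model name
        embed_tokenizer = AutoTokenizer.from_pretrained(embed_name)
        embed_encoder = AutoModel.from_pretrained(embed_name)
        vectors = await hf_embed(
            ["hello world", "LightRAG"], embed_tokenizer, embed_encoder
        )
        print(vectors.shape)

    asyncio.run(_demo())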