import os
import copy
import json
import aioboto3
import numpy as np
import ollama
from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout
from tenacity import (
    retry,
    stop_after_attempt,
    wait_exponential,
    retry_if_exception_type,
)
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from .base import BaseKVStorage
from .utils import compute_args_hash, wrap_embedding_func_with_attrs

os.environ["TOKENIZERS_PARALLELISM"] = "false"


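# Chat completion via the OpenAI API, with optional response caching through a
# BaseKVStorage instance passed as the "hashing_kv" keyword argument.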
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def openai_complete_if_cache(
    model,
    prompt,
    system_prompt=None,
    history_messages=[],
    base_url=None,
    api_key=None,
    **kwargs,
) -> str:
    if api_key:
        os.environ["OPENAI_API_KEY"] = api_key

    openai_async_client = (
        AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
    )
    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})
    if hashing_kv is not None:
        args_hash = compute_args_hash(model, messages)
        if_cache_return = await hashing_kv.get_by_id(args_hash)
        if if_cache_return is not None:
            return if_cache_return["return"]

    response = await openai_async_client.chat.completions.create(
        model=model, messages=messages, **kwargs
    )

    if hashing_kv is not None:
        await hashing_kv.upsert(
            {args_hash: {"return": response.choices[0].message.content, "model": model}}
        )
    return response.choices[0].message.content


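# Wrapper exception so tenacity can retry failed Amazon Bedrock calls.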
class BedrockError(Exception):
    """Generic error for issues related to Amazon Bedrock"""


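# Chat completion via the Amazon Bedrock Converse API, mirroring the OpenAI
# helper above (same optional "hashing_kv" caching).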
@retry(
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, max=60),
    retry=retry_if_exception_type((BedrockError)),
)
async def bedrock_complete_if_cache(
    model,
    prompt,
    system_prompt=None,
    history_messages=[],
    aws_access_key_id=None,
    aws_secret_access_key=None,
    aws_session_token=None,
    **kwargs,
) -> str:
    os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
        "AWS_ACCESS_KEY_ID", aws_access_key_id
    )
    os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get(
        "AWS_SECRET_ACCESS_KEY", aws_secret_access_key
    )
    os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
        "AWS_SESSION_TOKEN", aws_session_token
    )

    # Fix message history format
    messages = []
    for history_message in history_messages:
        message = copy.copy(history_message)
        message["content"] = [{"text": message["content"]}]
        messages.append(message)

    # Add user prompt
    messages.append({"role": "user", "content": [{"text": prompt}]})

    # Initialize Converse API arguments
    args = {"modelId": model, "messages": messages}

    # Define system prompt
    if system_prompt:
        args["system"] = [{"text": system_prompt}]

    # Map and set up inference parameters
    inference_params_map = {
        "max_tokens": "maxTokens",
        "top_p": "topP",
        "stop_sequences": "stopSequences",
    }
    if inference_params := list(
        set(kwargs) & set(["max_tokens", "temperature", "top_p", "stop_sequences"])
    ):
        args["inferenceConfig"] = {}
        for param in inference_params:
            args["inferenceConfig"][inference_params_map.get(param, param)] = (
                kwargs.pop(param)
            )

    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    if hashing_kv is not None:
        args_hash = compute_args_hash(model, messages)
        if_cache_return = await hashing_kv.get_by_id(args_hash)
        if if_cache_return is not None:
            return if_cache_return["return"]

    # Call model via Converse API
    session = aioboto3.Session()
    async with session.client("bedrock-runtime") as bedrock_async_client:
        try:
            response = await bedrock_async_client.converse(**args, **kwargs)
        except Exception as e:
            raise BedrockError(e)

        if hashing_kv is not None:
            await hashing_kv.upsert(
                {
                    args_hash: {
                        "return": response["output"]["message"]["content"][0]["text"],
                        "model": model,
                    }
                }
            )

        return response["output"]["message"]["content"][0]["text"]


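# Local chat completion with a Hugging Face causal LM; the tokenizer and model
# are loaded via from_pretrained on every call.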
async def hf_model_if_cache(
    model, prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
    model_name = model
    hf_tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto")
    if hf_tokenizer.pad_token is None:
        # print("use eos token")
        hf_tokenizer.pad_token = hf_tokenizer.eos_token
    hf_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})

    if hashing_kv is not None:
        args_hash = compute_args_hash(model, messages)
        if_cache_return = await hashing_kv.get_by_id(args_hash)
        if if_cache_return is not None:
            return if_cache_return["return"]
    input_prompt = ""
    # Prefer the tokenizer's built-in chat template when one is available.
    try:
        input_prompt = hf_tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    except Exception:
        try:
            # Some templates reject a system role: fold the system prompt into
            # the first user message and apply the chat template again.
            ori_message = copy.deepcopy(messages)
            if messages[0]["role"] == "system":
                messages[1]["content"] = (
                    "<system>"
                    + messages[0]["content"]
                    + "</system>\n"
                    + messages[1]["content"]
                )
                messages = messages[1:]
                input_prompt = hf_tokenizer.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=True
                )
        except Exception:
            # Last resort: serialize each message with simple <role>...</role> tags.
            len_message = len(ori_message)
            for msgid in range(len_message):
                input_prompt = (
                    input_prompt
                    + "<"
                    + ori_message[msgid]["role"]
                    + ">"
                    + ori_message[msgid]["content"]
                    + "</"
                    + ori_message[msgid]["role"]
                    + ">\n"
                )

    input_ids = hf_tokenizer(
        input_prompt, return_tensors="pt", padding=True, truncation=True
    ).to("cuda")
    output = hf_model.generate(
        **input_ids, max_new_tokens=200, num_return_sequences=1, early_stopping=True
    )
    response_text = hf_tokenizer.decode(output[0], skip_special_tokens=True)
    if hashing_kv is not None:
        await hashing_kv.upsert({args_hash: {"return": response_text, "model": model}})
    return response_text


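# Chat completion through a local or remote Ollama server.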
async def ollama_model_if_cache(
    model, prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
    # Drop OpenAI-style arguments that the Ollama chat API does not accept.
    kwargs.pop("max_tokens", None)
    kwargs.pop("response_format", None)
    # Route connection settings to the client rather than the chat call.
    host = kwargs.pop("host", None)
    timeout = kwargs.pop("timeout", None)

    ollama_client = ollama.AsyncClient(host=host, timeout=timeout)
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})

    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})
    if hashing_kv is not None:
        args_hash = compute_args_hash(model, messages)
        if_cache_return = await hashing_kv.get_by_id(args_hash)
        if if_cache_return is not None:
            return if_cache_return["return"]

    response = await ollama_client.chat(model=model, messages=messages, **kwargs)

    result = response["message"]["content"]

    if hashing_kv is not None:
        await hashing_kv.upsert({args_hash: {"return": result, "model": model}})

    return result


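# Convenience wrappers that bind a concrete model to the generic helpers above.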
async def gpt_4o_complete(
    prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
    return await openai_complete_if_cache(
        "gpt-4o",
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        **kwargs,
    )


async def gpt_4o_mini_complete(
    prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
    return await openai_complete_if_cache(
        "gpt-4o-mini",
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        **kwargs,
    )


async def bedrock_complete(
    prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
    return await bedrock_complete_if_cache(
        "anthropic.claude-3-haiku-20240307-v1:0",
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        **kwargs,
    )


async def hf_model_complete(
    prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
    model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
    return await hf_model_if_cache(
        model_name,
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        **kwargs,
    )


async def ollama_model_complete(
    prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
    model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
    return await ollama_model_if_cache(
        model_name,
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        **kwargs,
    )


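# Text embeddings via the OpenAI embeddings API (1536-dimensional vectors).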
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
)
async def openai_embedding(
    texts: list[str],
    model: str = "text-embedding-3-small",
    base_url: str = None,
    api_key: str = None,
) -> np.ndarray:
    if api_key:
        os.environ["OPENAI_API_KEY"] = api_key

    openai_async_client = (
        AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
    )
    response = await openai_async_client.embeddings.create(
        model=model, input=texts, encoding_format="float"
    )
    return np.array([dp.embedding for dp in response.data])


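# Text embeddings via Amazon Bedrock (Titan or Cohere embedding models).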
# @wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
# @retry(
#     stop=stop_after_attempt(3),
#     wait=wait_exponential(multiplier=1, min=4, max=10),
#     retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),  # TODO: fix exceptions
# )
async def bedrock_embedding(
    texts: list[str],
    model: str = "amazon.titan-embed-text-v2:0",
    aws_access_key_id=None,
    aws_secret_access_key=None,
    aws_session_token=None,
) -> np.ndarray:
    os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
        "AWS_ACCESS_KEY_ID", aws_access_key_id
    )
    os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get(
        "AWS_SECRET_ACCESS_KEY", aws_secret_access_key
    )
    os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
        "AWS_SESSION_TOKEN", aws_session_token
    )

    session = aioboto3.Session()
    async with session.client("bedrock-runtime") as bedrock_async_client:
        if (model_provider := model.split(".")[0]) == "amazon":
            embed_texts = []
            for text in texts:
                if "v2" in model:
                    body = json.dumps(
                        {
                            "inputText": text,
                            # 'dimensions': embedding_dim,
                            "embeddingTypes": ["float"],
                        }
                    )
                elif "v1" in model:
                    body = json.dumps({"inputText": text})
                else:
                    raise ValueError(f"Model {model} is not supported!")

                response = await bedrock_async_client.invoke_model(
                    modelId=model,
                    body=body,
                    accept="application/json",
                    contentType="application/json",
                )

                response_body = await response.get("body").json()

                embed_texts.append(response_body["embedding"])
        elif model_provider == "cohere":
            body = json.dumps(
                {"texts": texts, "input_type": "search_document", "truncate": "NONE"}
            )

            response = await bedrock_async_client.invoke_model(
                modelId=model,
                body=body,
                accept="application/json",
                contentType="application/json",
            )

            response_body = json.loads(await response.get("body").read())

            embed_texts = response_body["embeddings"]
        else:
            raise ValueError(f"Model provider '{model_provider}' is not supported!")

        return np.array(embed_texts)


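# Text embeddings from a local Hugging Face model: mean-pooled last hidden states.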
async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
    input_ids = tokenizer(
        texts, return_tensors="pt", padding=True, truncation=True
    ).input_ids
    with torch.no_grad():
        outputs = embed_model(input_ids)
        embeddings = outputs.last_hidden_state.mean(dim=1)
    return embeddings.detach().numpy()


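# Text embeddings from an Ollama embedding model, one request per input text.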
async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarray:
    embed_text = []
    ollama_client = ollama.Client(**kwargs)
    for text in texts:
        data = ollama_client.embeddings(model=embed_model, prompt=text)
        embed_text.append(data["embedding"])

    return np.array(embed_text)


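# Minimal smoke test; requires a valid OPENAI_API_KEY in the environment.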
if __name__ == "__main__":
    import asyncio

    async def main():
        result = await gpt_4o_mini_complete("How are you?")
        print(result)

    asyncio.run(main())