Mirror of https://github.com/HKUDS/LightRAG.git (synced 2025-11-03 19:29:38 +00:00)
Merge pull request #144 from tackhwa/lmdeploy_backend

[Feature] support lmdeploy backend

Commit b22049c514
examples/lightrag_lmdeploy_demo.py | 75 (new file)

@@ -0,0 +1,75 @@
import os

from lightrag import LightRAG, QueryParam
from lightrag.llm import lmdeploy_model_if_cache, hf_embedding
from lightrag.utils import EmbeddingFunc
from transformers import AutoModel, AutoTokenizer

WORKING_DIR = "./dickens"

if not os.path.exists(WORKING_DIR):
    os.mkdir(WORKING_DIR)


async def lmdeploy_model_complete(
    prompt=None, system_prompt=None, history_messages=[], **kwargs
) -> str:
    model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
    return await lmdeploy_model_if_cache(
        model_name,
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        ## Specify chat_template if your local path does not follow the original
        ## HF file naming, or if model_name is a PyTorch model on huggingface.co.
        ## See https://github.com/InternLM/lmdeploy/blob/main/lmdeploy/model.py
        ## for the list of chat templates available in lmdeploy.
        chat_template="llama3",
        # model_format="awq",  # uncomment if you are using an AWQ-quantized model.
        # quant_policy=8,  # online KV cache quantization: 4 = kv int4, 8 = kv int8.
        **kwargs,
    )


rag = LightRAG(
    working_dir=WORKING_DIR,
    llm_model_func=lmdeploy_model_complete,
    llm_model_name="meta-llama/Llama-3.1-8B-Instruct",  # use an absolute path for a local model
    embedding_func=EmbeddingFunc(
        embedding_dim=384,
        max_token_size=5000,
        func=lambda texts: hf_embedding(
            texts,
            tokenizer=AutoTokenizer.from_pretrained(
                "sentence-transformers/all-MiniLM-L6-v2"
            ),
            embed_model=AutoModel.from_pretrained(
                "sentence-transformers/all-MiniLM-L6-v2"
            ),
        ),
    ),
)


with open("./book.txt", "r", encoding="utf-8") as f:
    rag.insert(f.read())

# Perform naive search
print(
    rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
)

# Perform local search
print(
    rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
)

# Perform global search
print(
    rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
)

# Perform hybrid search
print(
    rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
)
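The demo reads a ./book.txt that this commit does not ship. Below is a minimal setup sketch for fetching one before running the script; the Project Gutenberg URL is an assumption (any UTF-8 text file works), not something provided by this PR:

# Hypothetical setup step, not part of this commit: download a public-domain
# text (assumed here: A Christmas Carol from Project Gutenberg) into ./book.txt.
import os
import urllib.request

BOOK_URL = "https://www.gutenberg.org/cache/epub/46/pg46.txt"  # assumed source URL

if not os.path.exists("./book.txt"):
    urllib.request.urlretrieve(BOOK_URL, "./book.txt")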
							
								
								
									
lightrag/llm.py | 133
@@ -286,7 +286,9 @@ async def hf_model_if_cache(
     output = hf_model.generate(
         **input_ids, max_new_tokens=512, num_return_sequences=1, early_stopping=True
     )
-    response_text = hf_tokenizer.decode(output[0][len(inputs["input_ids"][0]):], skip_special_tokens=True)
+    response_text = hf_tokenizer.decode(
+        output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
+    )
     if hashing_kv is not None:
         await hashing_kv.upsert({args_hash: {"return": response_text, "model": model}})
     return response_text
@@ -322,6 +324,135 @@ async def ollama_model_if_cache(
    return result


@lru_cache(maxsize=1)
def initialize_lmdeploy_pipeline(
    model,
    tp=1,
    chat_template=None,
    log_level="WARNING",
    model_format="hf",
    quant_policy=0,
):
    from lmdeploy import pipeline, ChatTemplateConfig, TurbomindEngineConfig

    lmdeploy_pipe = pipeline(
        model_path=model,
        backend_config=TurbomindEngineConfig(
            tp=tp, model_format=model_format, quant_policy=quant_policy
        ),
        chat_template_config=ChatTemplateConfig(model_name=chat_template)
        if chat_template
        else None,
        log_level="WARNING",
    )
    return lmdeploy_pipe


async def lmdeploy_model_if_cache(
    model,
    prompt,
    system_prompt=None,
    history_messages=[],
    chat_template=None,
    model_format="hf",
    quant_policy=0,
    **kwargs,
) -> str:
    """
    Args:
        model (str): The path to the model.
            It can be one of the following options:
                - i) A local directory path of a turbomind model converted by
                    the `lmdeploy convert` command or downloaded from ii) or iii).
                - ii) The model_id of an lmdeploy-quantized model hosted in a
                    model repo on huggingface.co, such as
                    "InternLM/internlm-chat-20b-4bit",
                    "lmdeploy/llama2-chat-70b-4bit", etc.
                - iii) The model_id of a model hosted in a model repo on
                    huggingface.co, such as "internlm/internlm-chat-7b",
                    "Qwen/Qwen-7B-Chat", "baichuan-inc/Baichuan2-7B-Chat",
                    and so on.
        chat_template (str): needed when the model is a PyTorch model on
            huggingface.co, such as "internlm-chat-7b", "Qwen-7B-Chat",
            "Baichuan2-7B-Chat", and so on, or when the model name of a local
            path does not match the original model name on HF.
        tp (int): tensor parallelism degree.
        prompt (Union[str, List[str]]): input texts to be completed.
        do_preprocess (bool): whether to pre-process the messages. Defaults to
            True, which means chat_template will be applied.
        skip_special_tokens (bool): whether to remove special tokens during
            decoding. Defaults to True.
        do_sample (bool): whether to use sampling; greedy decoding is used
            otherwise. Defaults to False, i.e. greedy decoding.
    """
    try:
        import lmdeploy
        from lmdeploy import version_info, GenerationConfig
    except Exception:
        raise ImportError("Please install lmdeploy before initializing the lmdeploy backend.")

    kwargs.pop("response_format", None)
    max_new_tokens = kwargs.pop("max_tokens", 512)
    tp = kwargs.pop("tp", 1)
    skip_special_tokens = kwargs.pop("skip_special_tokens", True)
    do_preprocess = kwargs.pop("do_preprocess", True)
    do_sample = kwargs.pop("do_sample", False)
    gen_params = kwargs

    version = version_info
    if do_sample is not None and version < (0, 6, 0):
        raise RuntimeError(
            "`do_sample` parameter is not supported by lmdeploy until "
            f"v0.6.0, but currently using lmdeploy {lmdeploy.__version__}"
        )
    else:
        do_sample = True
        gen_params.update(do_sample=do_sample)

    lmdeploy_pipe = initialize_lmdeploy_pipeline(
        model=model,
        tp=tp,
        chat_template=chat_template,
        model_format=model_format,
        quant_policy=quant_policy,
        log_level="WARNING",
    )

    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})

    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})
    if hashing_kv is not None:
        args_hash = compute_args_hash(model, messages)
        if_cache_return = await hashing_kv.get_by_id(args_hash)
        if if_cache_return is not None:
            return if_cache_return["return"]

    gen_config = GenerationConfig(
        skip_special_tokens=skip_special_tokens,
        max_new_tokens=max_new_tokens,
        **gen_params,
    )

    response = ""
    async for res in lmdeploy_pipe.generate(
        messages,
        gen_config=gen_config,
        do_preprocess=do_preprocess,
        stream_response=False,
        session_id=1,
    ):
        response += res.response

    if hashing_kv is not None:
        await hashing_kv.upsert({args_hash: {"return": response, "model": model}})
    return response


async def gpt_4o_complete(
    prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
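For orientation, here is a minimal sketch of calling the new lmdeploy_model_if_cache helper directly, outside of LightRAG. The model id is a placeholder assumption; without a hashing_kv argument no response caching happens, and max_tokens is forwarded to the generator as max_new_tokens, as shown in the diff above:

import asyncio

from lightrag.llm import lmdeploy_model_if_cache


async def main():
    answer = await lmdeploy_model_if_cache(
        "internlm/internlm2_5-7b-chat",  # placeholder HF model id (assumption)
        "Summarize the plot of A Christmas Carol in one sentence.",
        system_prompt="You are a concise assistant.",
        chat_template=None,  # let lmdeploy resolve the template for an HF repo id
        tp=1,
        max_tokens=256,
    )
    print(answer)


if __name__ == "__main__":
    asyncio.run(main())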
@@ -13,3 +13,4 @@ tiktoken
 torch
 transformers
 xxhash
+# lmdeploy[all]
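lmdeploy[all] is left commented out in the dependency list because the backend is optional; the try/except added to lightrag/llm.py raises an ImportError when the package is missing. A small pre-flight sketch, assuming only the lmdeploy package itself, for checking the installation and the version gate used by do_sample:

# Optional pre-flight check before enabling the lmdeploy backend.
try:
    import lmdeploy
    from lmdeploy import version_info
except ImportError as exc:
    raise ImportError(
        "lmdeploy is not installed; install it (e.g. the 'lmdeploy[all]' extra) first."
    ) from exc

print("lmdeploy version:", lmdeploy.__version__)
if version_info < (0, 6, 0):
    print("note: the do_sample path in lightrag/llm.py requires lmdeploy >= 0.6.0")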