From 38e1956395f9d4925b51e52ccce3ad1f42ffd3e7 Mon Sep 17 00:00:00 2001
From: david
Date: Wed, 13 Nov 2024 14:20:36 +0800
Subject: [PATCH] fix hf embedding to support loading to different device

---
 lightrag/llm.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/lightrag/llm.py b/lightrag/llm.py
index f4045e80..6cc46c85 100644
--- a/lightrag/llm.py
+++ b/lightrag/llm.py
@@ -693,13 +693,17 @@ async def bedrock_embedding(
 
 
 async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
+    device = next(embed_model.parameters()).device
     input_ids = tokenizer(
         texts, return_tensors="pt", padding=True, truncation=True
-    ).input_ids
+    ).input_ids.to(device)
     with torch.no_grad():
         outputs = embed_model(input_ids)
     embeddings = outputs.last_hidden_state.mean(dim=1)
-    return embeddings.detach().numpy()
+    if embeddings.dtype == torch.bfloat16:
+        return embeddings.detach().to(torch.float32).cpu().numpy()
+    else:
+        return embeddings.detach().cpu().numpy()
 
 
 async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarray: