Mirror of https://github.com/HKUDS/LightRAG.git (synced 2025-07-25 01:40:28 +00:00)
fix hf embedding to support loading to different device
commit 38e1956395 (parent df671a2bbe)
@@ -693,13 +693,17 @@ async def bedrock_embedding(
 
 
 async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
+    device = next(embed_model.parameters()).device
     input_ids = tokenizer(
         texts, return_tensors="pt", padding=True, truncation=True
-    ).input_ids
+    ).input_ids.to(device)
     with torch.no_grad():
         outputs = embed_model(input_ids)
         embeddings = outputs.last_hidden_state.mean(dim=1)
-    return embeddings.detach().numpy()
+    if embeddings.dtype == torch.bfloat16:
+        return embeddings.detach().to(torch.float32).cpu().numpy()
+    else:
+        return embeddings.detach().cpu().numpy()
 
 
 async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarray:
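In short, the patch makes hf_embedding device-agnostic: it reads the device off the model's own parameters, moves the tokenized input_ids there before the forward pass, and upcasts bfloat16 embeddings to float32 before calling .numpy(), since NumPy has no bfloat16 dtype and the direct conversion would raise a TypeError. Below is a minimal sketch (not part of the commit) of how the patched function could be exercised, assuming hf_embedding is importable from lightrag.llm (the file this hunk patches) and that torch and transformers are installed; the model id is an arbitrary illustrative choice, not one LightRAG prescribes.

# Sketch only, not part of the commit. The model id and import path below
# are illustrative assumptions.
import asyncio

import numpy as np
import torch
from transformers import AutoModel, AutoTokenizer

from lightrag.llm import hf_embedding  # assumed import path for the patched function


async def main() -> None:
    name = "sentence-transformers/all-MiniLM-L6-v2"  # hypothetical model choice
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # bfloat16 on GPU exercises the new float32 upcast branch; float32 on CPU
    # takes the unchanged return path.
    dtype = torch.bfloat16 if device == "cuda" else torch.float32

    tokenizer = AutoTokenizer.from_pretrained(name)
    embed_model = AutoModel.from_pretrained(name, torch_dtype=dtype).to(device)

    # Before this fix, a model on "cuda" would fail here: input_ids stayed on
    # the CPU while the model weights lived on the GPU.
    vecs = await hf_embedding(["hello world", "knowledge graphs"], tokenizer, embed_model)
    assert isinstance(vecs, np.ndarray) and vecs.dtype == np.float32
    print(vecs.shape)


asyncio.run(main())

Deriving the device from the model's parameters, rather than adding a device argument, keeps the function signature unchanged for existing callers.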