Fix hf_embedding so it works when the model is loaded on a different device

This commit is contained in:
david 2024-11-13 14:20:36 +08:00
parent df671a2bbe
commit 38e1956395

View File

@ -693,13 +693,17 @@ async def bedrock_embedding(
async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
    """Embed *texts* by mean-pooling the model's last hidden state.

    The batch is tokenized with padding and truncation, moved to whatever
    device the model's parameters live on, and run through the model with
    gradients disabled.

    Args:
        texts: Strings to embed, processed as a single batch.
        tokenizer: Callable returning an encoding that supports ``.to(device)``
            and exposes ``"input_ids"`` and ``"attention_mask"`` entries
            (presumably a Hugging Face tokenizer -- confirm against callers).
        embed_model: Model exposing ``parameters()`` and returning an object
            with a ``last_hidden_state`` tensor.

    Returns:
        One embedding row per input text, as a NumPy array on the host.

    Note:
        Although declared ``async``, the body contains no ``await``; the
        forward pass runs synchronously inside the coroutine.
    """
    # Inputs must live on the same device as the model weights.
    device = next(embed_model.parameters()).device
    encoded = tokenizer(
        texts, return_tensors="pt", padding=True, truncation=True
    ).to(device)

    with torch.no_grad():
        # Pass the attention mask so padding tokens added for batching are
        # not attended to; omitting it lets padding leak into the outputs.
        outputs = embed_model(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
        )
        # NOTE: this is an unmasked mean over every sequence position, so
        # padded positions still take part in the average.
        embeddings = outputs.last_hidden_state.mean(dim=1)

    embeddings = embeddings.detach()
    # NumPy has no bfloat16 dtype, so upcast before converting.
    if embeddings.dtype == torch.bfloat16:
        embeddings = embeddings.to(torch.float32)
    # .cpu() is required when the model runs on an accelerator.
    return embeddings.cpu().numpy()
async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarray: