mirror of https://github.com/HKUDS/LightRAG.git (synced 2025-10-31 01:39:56 +00:00)
fix hf embedding to support loading to a different device
		
parent df671a2bbe
commit 38e1956395
@@ -693,13 +693,17 @@ async def bedrock_embedding(
 
 
 async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
+    device = next(embed_model.parameters()).device
     input_ids = tokenizer(
         texts, return_tensors="pt", padding=True, truncation=True
-    ).input_ids
+    ).input_ids.to(device)
     with torch.no_grad():
         outputs = embed_model(input_ids)
         embeddings = outputs.last_hidden_state.mean(dim=1)
-    return embeddings.detach().numpy()
+    if embeddings.dtype == torch.bfloat16:
+        return embeddings.detach().to(torch.float32).cpu().numpy()
+    else:
+        return embeddings.detach().cpu().numpy()
 
 
 async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarray:
Author: david
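For context, a minimal usage sketch of the patched function. The import path lightrag.llm, the model checkpoint, and the surrounding driver code are assumptions for illustration, not part of the commit. The patch reads the model's device, moves the tokenized input_ids there before the forward pass, and copies the result back to CPU before converting to numpy (casting bfloat16 to float32 first, since numpy has no bfloat16 dtype), so calling hf_embedding with a GPU-resident model no longer fails with a device-mismatch error.

import asyncio

import torch
from transformers import AutoModel, AutoTokenizer

from lightrag.llm import hf_embedding  # assumed import path for the patched function

# Illustrative checkpoint, not part of the commit; any encoder-style HF model works.
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Loading the model onto a non-CPU device is the case this commit fixes.
embed_model = AutoModel.from_pretrained(MODEL_NAME).to(device)

async def main():
    # hf_embedding tokenizes the texts, runs the encoder without gradients,
    # mean-pools the last hidden state, and returns a CPU numpy array.
    vectors = await hf_embedding(["hello world", "LightRAG"], tokenizer, embed_model)
    print(vectors.shape, vectors.dtype)  # one row per input text, float32

asyncio.run(main())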